Reformat source code with black.
This is the result of:
$ black --line-length 119 examples templates transformers utils hubconf.py setup.py
There's a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.
This is also Thomas' preference, because it allows for explicit variable
names, to make the code easier to understand.
This commit is contained in:
@@ -247,16 +247,18 @@ the wall, slowly on into the Social Predestination Room.
|
||||
as they entered."""
|
||||
|
||||
|
||||
def create_setup_and_compute(model_names: List[str],
|
||||
gpu: bool = True,
|
||||
tensorflow: bool = False,
|
||||
average_over: int = 3,
|
||||
torchscript: bool = False,
|
||||
xla: bool = False,
|
||||
amp: bool = False,
|
||||
fp16: bool = False,
|
||||
save_to_csv: bool = False,
|
||||
csv_filename: str = f"results_{round(time())}.csv"):
|
||||
def create_setup_and_compute(
|
||||
model_names: List[str],
|
||||
gpu: bool = True,
|
||||
tensorflow: bool = False,
|
||||
average_over: int = 3,
|
||||
torchscript: bool = False,
|
||||
xla: bool = False,
|
||||
amp: bool = False,
|
||||
fp16: bool = False,
|
||||
save_to_csv: bool = False,
|
||||
csv_filename: str = f"results_{round(time())}.csv",
|
||||
):
|
||||
if xla:
|
||||
tf.config.optimizer.set_jit(True)
|
||||
if amp:
|
||||
@@ -266,7 +268,7 @@ def create_setup_and_compute(model_names: List[str],
|
||||
dictionary = {model_name: {} for model_name in model_names}
|
||||
results = _compute_tensorflow(model_names, dictionary, average_over, amp)
|
||||
else:
|
||||
device = 'cuda' if (gpu and torch.cuda.is_available()) else 'cpu'
|
||||
device = "cuda" if (gpu and torch.cuda.is_available()) else "cpu"
|
||||
dictionary = {model_name: {} for model_name in model_names}
|
||||
results = _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16)
|
||||
|
||||
@@ -276,34 +278,52 @@ def create_setup_and_compute(model_names: List[str],
|
||||
for batch_size in results[model_name]["bs"]:
|
||||
print("\t\t" + f"===== BATCH SIZE: {batch_size} =====")
|
||||
for slice_size in results[model_name]["ss"]:
|
||||
result = results[model_name]['results'][batch_size][slice_size]
|
||||
result = results[model_name]["results"][batch_size][slice_size]
|
||||
if isinstance(result, str):
|
||||
print(f"\t\t{model_name}/{batch_size}/{slice_size}: "
|
||||
f"{result}")
|
||||
print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result}")
|
||||
else:
|
||||
print(f"\t\t{model_name}/{batch_size}/{slice_size}: "
|
||||
f"{(round(1000 * result) / 1000)}"
|
||||
f"s")
|
||||
print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{(round(1000 * result) / 1000)}" f"s")
|
||||
|
||||
if save_to_csv:
|
||||
with open(csv_filename, mode='w') as csv_file:
|
||||
fieldnames = ['model',
|
||||
'1x8', '1x64', '1x128', '1x256', '1x512', '1x1024',
|
||||
'2x8', '2x64', '2x128', '2x256', '2x512', '2x1024',
|
||||
'4x8', '4x64', '4x128', '4x256', '4x512', '4x1024',
|
||||
'8x8', '8x64', '8x128', '8x256', '8x512', '8x1024',
|
||||
]
|
||||
with open(csv_filename, mode="w") as csv_file:
|
||||
fieldnames = [
|
||||
"model",
|
||||
"1x8",
|
||||
"1x64",
|
||||
"1x128",
|
||||
"1x256",
|
||||
"1x512",
|
||||
"1x1024",
|
||||
"2x8",
|
||||
"2x64",
|
||||
"2x128",
|
||||
"2x256",
|
||||
"2x512",
|
||||
"2x1024",
|
||||
"4x8",
|
||||
"4x64",
|
||||
"4x128",
|
||||
"4x256",
|
||||
"4x512",
|
||||
"4x1024",
|
||||
"8x8",
|
||||
"8x64",
|
||||
"8x128",
|
||||
"8x256",
|
||||
"8x512",
|
||||
"8x1024",
|
||||
]
|
||||
|
||||
writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
|
||||
for model_name in model_names:
|
||||
model_results = {
|
||||
f'{bs}x{ss}': results[model_name]['results'][bs][ss]
|
||||
f"{bs}x{ss}": results[model_name]["results"][bs][ss]
|
||||
for bs in results[model_name]["results"]
|
||||
for ss in results[model_name]['results'][bs]
|
||||
for ss in results[model_name]["results"][bs]
|
||||
}
|
||||
writer.writerow({'model': model_name, **model_results})
|
||||
writer.writerow({"model": model_name, **model_results})
|
||||
|
||||
|
||||
def _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16):
|
||||
@@ -343,7 +363,7 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript,
|
||||
|
||||
print("Going through model with sequence of shape", sequence.shape)
|
||||
runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
|
||||
average_time = sum(runtimes)/float(len(runtimes)) / 3.0
|
||||
average_time = sum(runtimes) / float(len(runtimes)) / 3.0
|
||||
dictionary[model_name]["results"][batch_size][slice_size] = average_time
|
||||
except RuntimeError as e:
|
||||
print("Doesn't fit on GPU.", e)
|
||||
@@ -379,7 +399,9 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
|
||||
if max_input_size is not None and slice_size > max_input_size:
|
||||
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
|
||||
else:
|
||||
sequence = tf.stack([tf.squeeze(tf.constant(tokenized_sequence[:slice_size])[None, :])] * batch_size)
|
||||
sequence = tf.stack(
|
||||
[tf.squeeze(tf.constant(tokenized_sequence[:slice_size])[None, :])] * batch_size
|
||||
)
|
||||
|
||||
try:
|
||||
print("Going through model with sequence of shape", sequence.shape)
|
||||
@@ -387,7 +409,7 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
|
||||
inference(sequence)
|
||||
|
||||
runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
|
||||
average_time = sum(runtimes)/float(len(runtimes)) / 3.0
|
||||
average_time = sum(runtimes) / float(len(runtimes)) / 3.0
|
||||
dictionary[model_name]["results"][batch_size][slice_size] = average_time
|
||||
except tf.errors.ResourceExhaustedError as e:
|
||||
print("Doesn't fit on GPU.", e)
|
||||
@@ -399,33 +421,64 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--models", required=False, type=str, default='all', help="Model checkpoints to be provided "
|
||||
"to the AutoModel classes. Leave "
|
||||
"blank to benchmark the base version "
|
||||
"of all available model "
|
||||
"architectures.")
|
||||
parser.add_argument("--torch", required=False, action="store_true", help="Benchmark the Pytorch version of the "
|
||||
"models")
|
||||
parser.add_argument("--torch_cuda", required=False, action="store_true", help="Pytorch only: run on available "
|
||||
"cuda devices")
|
||||
parser.add_argument("--torchscript", required=False, action="store_true", help="Pytorch only: trace the models "
|
||||
"using torchscript")
|
||||
parser.add_argument("--tensorflow", required=False, action="store_true", help="Benchmark the TensorFlow version "
|
||||
"of the models. Will run on GPU if "
|
||||
"the correct dependencies are "
|
||||
"installed")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
required=False,
|
||||
type=str,
|
||||
default="all",
|
||||
help="Model checkpoints to be provided "
|
||||
"to the AutoModel classes. Leave "
|
||||
"blank to benchmark the base version "
|
||||
"of all available model "
|
||||
"architectures.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torch", required=False, action="store_true", help="Benchmark the Pytorch version of the " "models"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torch_cuda", required=False, action="store_true", help="Pytorch only: run on available " "cuda devices"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torchscript",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Pytorch only: trace the models " "using torchscript",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tensorflow",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Benchmark the TensorFlow version "
|
||||
"of the models. Will run on GPU if "
|
||||
"the correct dependencies are "
|
||||
"installed",
|
||||
)
|
||||
parser.add_argument("--xla", required=False, action="store_true", help="TensorFlow only: use XLA acceleration.")
|
||||
parser.add_argument("--amp", required=False, action="store_true", help="TensorFlow only: use automatic mixed precision acceleration.")
|
||||
parser.add_argument("--fp16", required=False, action="store_true", help="PyTorch only: use FP16 to accelerate inference.")
|
||||
parser.add_argument("--keras_predict", required=False, action="store_true", help="Whether to use model.predict "
|
||||
"instead of model() to do a "
|
||||
"forward pass.")
|
||||
parser.add_argument(
|
||||
"--amp",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="TensorFlow only: use automatic mixed precision acceleration.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16", required=False, action="store_true", help="PyTorch only: use FP16 to accelerate inference."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keras_predict",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Whether to use model.predict " "instead of model() to do a " "forward pass.",
|
||||
)
|
||||
parser.add_argument("--save_to_csv", required=False, action="store_true", help="Save to a CSV file.")
|
||||
parser.add_argument("--csv_filename", required=False, default=None, help="CSV filename used if saving results to csv.")
|
||||
parser.add_argument("--average_over", required=False, default=30, type=int, help="Times an experiment will be run.")
|
||||
parser.add_argument(
|
||||
"--csv_filename", required=False, default=None, help="CSV filename used if saving results to csv."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--average_over", required=False, default=30, type=int, help="Times an experiment will be run."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.models == 'all':
|
||||
if args.models == "all":
|
||||
args.models = [
|
||||
"gpt2",
|
||||
"bert-base-cased",
|
||||
@@ -436,7 +489,7 @@ def main():
|
||||
"distilbert-base-uncased",
|
||||
"distilgpt2",
|
||||
"roberta-base",
|
||||
"ctrl"
|
||||
"ctrl",
|
||||
]
|
||||
else:
|
||||
args.models = args.models.split()
|
||||
@@ -453,7 +506,7 @@ def main():
|
||||
fp16=args.fp16,
|
||||
save_to_csv=args.save_to_csv,
|
||||
csv_filename=args.csv_filename,
|
||||
average_over=args.average_over
|
||||
average_over=args.average_over,
|
||||
)
|
||||
else:
|
||||
raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
|
||||
@@ -467,11 +520,11 @@ def main():
|
||||
amp=args.amp,
|
||||
save_to_csv=args.save_to_csv,
|
||||
csv_filename=args.csv_filename,
|
||||
average_over=args.average_over
|
||||
average_over=args.average_over,
|
||||
)
|
||||
else:
|
||||
raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -10,38 +10,37 @@ from transformers.modeling_camembert import CamembertForMaskedLM
|
||||
|
||||
def fill_mask(masked_input, model, tokenizer, topk=5):
|
||||
# Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py
|
||||
assert masked_input.count('<mask>') == 1
|
||||
assert masked_input.count("<mask>") == 1
|
||||
input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
logits = model(input_ids)[0] # The last hidden-state is the first element of the output tuple
|
||||
masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
|
||||
logits = logits[0, masked_index, :]
|
||||
prob = logits.softmax(dim=0)
|
||||
values, indices = prob.topk(k=topk, dim=0)
|
||||
topk_predicted_token_bpe = ' '.join([tokenizer.convert_ids_to_tokens(indices[i].item())
|
||||
for i in range(len(indices))])
|
||||
topk_predicted_token_bpe = " ".join(
|
||||
[tokenizer.convert_ids_to_tokens(indices[i].item()) for i in range(len(indices))]
|
||||
)
|
||||
masked_token = tokenizer.mask_token
|
||||
topk_filled_outputs = []
|
||||
for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(' ')):
|
||||
predicted_token = predicted_token_bpe.replace('\u2581', ' ')
|
||||
for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(" ")):
|
||||
predicted_token = predicted_token_bpe.replace("\u2581", " ")
|
||||
if " {0}".format(masked_token) in masked_input:
|
||||
topk_filled_outputs.append((
|
||||
masked_input.replace(
|
||||
' {0}'.format(masked_token), predicted_token
|
||||
),
|
||||
values[index].item(),
|
||||
predicted_token,
|
||||
))
|
||||
topk_filled_outputs.append(
|
||||
(
|
||||
masked_input.replace(" {0}".format(masked_token), predicted_token),
|
||||
values[index].item(),
|
||||
predicted_token,
|
||||
)
|
||||
)
|
||||
else:
|
||||
topk_filled_outputs.append((
|
||||
masked_input.replace(masked_token, predicted_token),
|
||||
values[index].item(),
|
||||
predicted_token,
|
||||
))
|
||||
topk_filled_outputs.append(
|
||||
(masked_input.replace(masked_token, predicted_token), values[index].item(), predicted_token,)
|
||||
)
|
||||
return topk_filled_outputs
|
||||
|
||||
|
||||
tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
|
||||
model = CamembertForMaskedLM.from_pretrained('camembert-base')
|
||||
tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
|
||||
model = CamembertForMaskedLM.from_pretrained("camembert-base")
|
||||
model.eval()
|
||||
|
||||
masked_input = "Le camembert est <mask> :)"
|
||||
|
||||
@@ -36,34 +36,42 @@ from tqdm import tqdm, trange
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
|
||||
from transformers import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
|
||||
AdamW, cached_path, WEIGHTS_NAME, CONFIG_NAME,
|
||||
get_linear_schedule_with_warmup)
|
||||
from transformers import (
|
||||
OpenAIGPTDoubleHeadsModel,
|
||||
OpenAIGPTTokenizer,
|
||||
AdamW,
|
||||
cached_path,
|
||||
WEIGHTS_NAME,
|
||||
CONFIG_NAME,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
|
||||
ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def accuracy(out, labels):
|
||||
outputs = np.argmax(out, axis=1)
|
||||
return np.sum(outputs == labels)
|
||||
|
||||
|
||||
def load_rocstories_dataset(dataset_path):
|
||||
""" Output a list of tuples(story, 1st continuation, 2nd continuation, label) """
|
||||
with open(dataset_path, encoding='utf_8') as f:
|
||||
with open(dataset_path, encoding="utf_8") as f:
|
||||
f = csv.reader(f)
|
||||
output = []
|
||||
next(f) # skip the first line
|
||||
next(f) # skip the first line
|
||||
for line in tqdm(f):
|
||||
output.append((' '.join(line[1:5]), line[5], line[6], int(line[-1])-1))
|
||||
output.append((" ".join(line[1:5]), line[5], line[6], int(line[-1]) - 1))
|
||||
return output
|
||||
|
||||
|
||||
def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, delimiter_token, clf_token):
|
||||
""" Pre-process datasets containing lists of tuples(story, 1st continuation, 2nd continuation, label)
|
||||
|
||||
@@ -80,56 +88,68 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
|
||||
for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
|
||||
with_cont1 = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token]
|
||||
with_cont2 = [start_token] + story[:cap_length] + [delimiter_token] + cont2[:cap_length] + [clf_token]
|
||||
input_ids[i, 0, :len(with_cont1)] = with_cont1
|
||||
input_ids[i, 1, :len(with_cont2)] = with_cont2
|
||||
input_ids[i, 0, : len(with_cont1)] = with_cont1
|
||||
input_ids[i, 1, : len(with_cont2)] = with_cont2
|
||||
mc_token_ids[i, 0] = len(with_cont1) - 1
|
||||
mc_token_ids[i, 1] = len(with_cont2) - 1
|
||||
lm_labels[i, 0, :len(with_cont1)] = with_cont1
|
||||
lm_labels[i, 1, :len(with_cont2)] = with_cont2
|
||||
lm_labels[i, 0, : len(with_cont1)] = with_cont1
|
||||
lm_labels[i, 1, : len(with_cont2)] = with_cont2
|
||||
mc_labels[i] = mc_label
|
||||
all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
|
||||
tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
|
||||
return tensor_datasets
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_name', type=str, default='openai-gpt',
|
||||
help='pretrained model name')
|
||||
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
parser.add_argument('--train_dataset', type=str, default='')
|
||||
parser.add_argument('--eval_dataset', type=str, default='')
|
||||
parser.add_argument('--seed', type=int, default=42)
|
||||
parser.add_argument('--num_train_epochs', type=int, default=3)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--eval_batch_size', type=int, default=16)
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument('--max_grad_norm', type=int, default=1)
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training \
|
||||
steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before\
|
||||
performing a backward/update pass.")
|
||||
parser.add_argument('--learning_rate', type=float, default=6.25e-5)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
|
||||
parser.add_argument('--weight_decay', type=float, default=0.01)
|
||||
parser.add_argument('--lm_coef', type=float, default=0.9)
|
||||
parser.add_argument('--n_valid', type=int, default=374)
|
||||
parser.add_argument("--model_name", type=str, default="openai-gpt", help="pretrained model name")
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--train_dataset", type=str, default="")
|
||||
parser.add_argument("--eval_dataset", type=str, default="")
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=3)
|
||||
parser.add_argument("--train_batch_size", type=int, default=8)
|
||||
parser.add_argument("--eval_batch_size", type=int, default=16)
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training \
|
||||
steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before\
|
||||
performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", type=float, default=6.25e-5)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--lr_schedule", type=str, default="warmup_linear")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01)
|
||||
parser.add_argument("--lm_coef", type=float, default=0.9)
|
||||
parser.add_argument("--n_valid", type=int, default=374)
|
||||
|
||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
@@ -152,7 +172,7 @@ def main():
|
||||
# Load tokenizer and model
|
||||
# This loading functions also add new tokens and embeddings called `special tokens`
|
||||
# These new embeddings will be fine-tuned on the RocStories dataset
|
||||
special_tokens = ['_start_', '_delimiter_', '_classify_']
|
||||
special_tokens = ["_start_", "_delimiter_", "_classify_"]
|
||||
tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name)
|
||||
tokenizer.add_tokens(special_tokens)
|
||||
special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
|
||||
@@ -163,6 +183,7 @@ def main():
|
||||
# Load and encode the datasets
|
||||
if not args.train_dataset and not args.eval_dataset:
|
||||
roc_stories = cached_path(ROCSTORIES_URL)
|
||||
|
||||
def tokenize_and_encode(obj):
|
||||
""" Tokenize and encode a nested object """
|
||||
if isinstance(obj, str):
|
||||
@@ -170,6 +191,7 @@ def main():
|
||||
elif isinstance(obj, int):
|
||||
return obj
|
||||
return list(tokenize_and_encode(o) for o in obj)
|
||||
|
||||
logger.info("Encoding dataset...")
|
||||
train_dataset = load_rocstories_dataset(args.train_dataset)
|
||||
eval_dataset = load_rocstories_dataset(args.eval_dataset)
|
||||
@@ -178,8 +200,11 @@ def main():
|
||||
|
||||
# Compute the max input length for the Transformer
|
||||
max_length = model.config.n_positions // 2 - 2
|
||||
input_length = max(len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3 \
|
||||
for dataset in encoded_datasets for story, cont1, cont2, _ in dataset)
|
||||
input_length = max(
|
||||
len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3
|
||||
for dataset in encoded_datasets
|
||||
for story, cont1, cont2, _ in dataset
|
||||
)
|
||||
input_length = min(input_length, model.config.n_positions) # Max size of input for the pre-trained model
|
||||
|
||||
# Prepare inputs tensors and dataloaders
|
||||
@@ -198,20 +223,23 @@ def main():
|
||||
if args.do_train:
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps //\
|
||||
(len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader)\
|
||||
// args.gradient_accumulation_steps * args.num_train_epochs
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
param_optimizer = list(model.named_parameters())
|
||||
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
{
|
||||
"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
if args.do_train:
|
||||
nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
|
||||
@@ -230,14 +258,16 @@ def main():
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
tr_loss += loss.item()
|
||||
exp_average_loss = loss.item() if exp_average_loss is None else 0.7*exp_average_loss+0.3*loss.item()
|
||||
exp_average_loss = (
|
||||
loss.item() if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item()
|
||||
)
|
||||
nb_tr_steps += 1
|
||||
tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, scheduler.get_lr()[0])
|
||||
|
||||
# Save a trained model
|
||||
if args.do_train:
|
||||
# Save a trained model, configuration and tokenizer
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model itself
|
||||
model_to_save = model.module if hasattr(model, "module") else model # Only save the model itself
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||
@@ -260,10 +290,12 @@ def main():
|
||||
batch = tuple(t.to(device) for t in batch)
|
||||
input_ids, mc_token_ids, lm_labels, mc_labels = batch
|
||||
with torch.no_grad():
|
||||
_, mc_loss, _, mc_logits = model(input_ids, mc_token_ids=mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels)
|
||||
_, mc_loss, _, mc_logits = model(
|
||||
input_ids, mc_token_ids=mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels
|
||||
)
|
||||
|
||||
mc_logits = mc_logits.detach().cpu().numpy()
|
||||
mc_labels = mc_labels.to('cpu').numpy()
|
||||
mc_labels = mc_labels.to("cpu").numpy()
|
||||
tmp_eval_accuracy = accuracy(mc_logits, mc_labels)
|
||||
|
||||
eval_loss += mc_loss.mean().item()
|
||||
@@ -274,10 +306,8 @@ def main():
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
eval_accuracy = eval_accuracy / nb_eval_examples
|
||||
train_loss = tr_loss/nb_tr_steps if args.do_train else None
|
||||
result = {'eval_loss': eval_loss,
|
||||
'eval_accuracy': eval_accuracy,
|
||||
'train_loss': train_loss}
|
||||
train_loss = tr_loss / nb_tr_steps if args.do_train else None
|
||||
result = {"eval_loss": eval_loss, "eval_accuracy": eval_accuracy, "train_loss": train_loss}
|
||||
|
||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
@@ -286,5 +316,6 @@ def main():
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -28,8 +28,7 @@ import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
try:
|
||||
@@ -39,31 +38,23 @@ except:
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForMultipleChoice, BertTokenizer)
|
||||
from transformers import WEIGHTS_NAME, BertConfig, BertForMultipleChoice, BertTokenizer
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
||||
for conf in [BertConfig]), ())
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in [BertConfig]), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
|
||||
"bert": (BertConfig, BertForMultipleChoice, BertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
class SwagExample(object):
|
||||
"""A single training/test example for the SWAG dataset."""
|
||||
def __init__(self,
|
||||
swag_id,
|
||||
context_sentence,
|
||||
start_ending,
|
||||
ending_0,
|
||||
ending_1,
|
||||
ending_2,
|
||||
ending_3,
|
||||
label = None):
|
||||
|
||||
def __init__(self, swag_id, context_sentence, start_ending, ending_0, ending_1, ending_2, ending_3, label=None):
|
||||
self.swag_id = swag_id
|
||||
self.context_sentence = context_sentence
|
||||
self.start_ending = start_ending
|
||||
@@ -94,57 +85,49 @@ class SwagExample(object):
|
||||
|
||||
return ", ".join(l)
|
||||
|
||||
class InputFeatures(object):
|
||||
def __init__(self,
|
||||
example_id,
|
||||
choices_features,
|
||||
label
|
||||
|
||||
):
|
||||
class InputFeatures(object):
|
||||
def __init__(self, example_id, choices_features, label):
|
||||
self.example_id = example_id
|
||||
self.choices_features = [
|
||||
{
|
||||
'input_ids': input_ids,
|
||||
'input_mask': input_mask,
|
||||
'segment_ids': segment_ids
|
||||
}
|
||||
{"input_ids": input_ids, "input_mask": input_mask, "segment_ids": segment_ids}
|
||||
for _, input_ids, input_mask, segment_ids in choices_features
|
||||
]
|
||||
self.label = label
|
||||
|
||||
|
||||
def read_swag_examples(input_file, is_training=True):
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
lines = []
|
||||
for line in reader:
|
||||
if sys.version_info[0] == 2:
|
||||
line = list(unicode(cell, 'utf-8') for cell in line)
|
||||
line = list(unicode(cell, "utf-8") for cell in line)
|
||||
lines.append(line)
|
||||
|
||||
if is_training and lines[0][-1] != 'label':
|
||||
raise ValueError(
|
||||
"For training, the input file must contain a label column."
|
||||
)
|
||||
if is_training and lines[0][-1] != "label":
|
||||
raise ValueError("For training, the input file must contain a label column.")
|
||||
|
||||
examples = [
|
||||
SwagExample(
|
||||
swag_id = line[2],
|
||||
context_sentence = line[4],
|
||||
start_ending = line[5], # in the swag dataset, the
|
||||
# common beginning of each
|
||||
# choice is stored in "sent2".
|
||||
ending_0 = line[7],
|
||||
ending_1 = line[8],
|
||||
ending_2 = line[9],
|
||||
ending_3 = line[10],
|
||||
label = int(line[11]) if is_training else None
|
||||
) for line in lines[1:] # we skip the line with the column names
|
||||
swag_id=line[2],
|
||||
context_sentence=line[4],
|
||||
start_ending=line[5], # in the swag dataset, the
|
||||
# common beginning of each
|
||||
# choice is stored in "sent2".
|
||||
ending_0=line[7],
|
||||
ending_1=line[8],
|
||||
ending_2=line[9],
|
||||
ending_3=line[10],
|
||||
label=int(line[11]) if is_training else None,
|
||||
)
|
||||
for line in lines[1:] # we skip the line with the column names
|
||||
]
|
||||
|
||||
return examples
|
||||
|
||||
def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
is_training):
|
||||
|
||||
def convert_examples_to_features(examples, tokenizer, max_seq_length, is_training):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
# Swag is a multiple choice task. To perform this task using Bert,
|
||||
@@ -204,23 +187,18 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
logger.info("swag_id: {}".format(example.swag_id))
|
||||
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
|
||||
logger.info("choice: {}".format(choice_idx))
|
||||
logger.info("tokens: {}".format(' '.join(tokens)))
|
||||
logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
|
||||
logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
|
||||
logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
|
||||
logger.info("tokens: {}".format(" ".join(tokens)))
|
||||
logger.info("input_ids: {}".format(" ".join(map(str, input_ids))))
|
||||
logger.info("input_mask: {}".format(" ".join(map(str, input_mask))))
|
||||
logger.info("segment_ids: {}".format(" ".join(map(str, segment_ids))))
|
||||
if is_training:
|
||||
logger.info("label: {}".format(label))
|
||||
|
||||
features.append(
|
||||
InputFeatures(
|
||||
example_id = example.swag_id,
|
||||
choices_features = choices_features,
|
||||
label = label
|
||||
)
|
||||
)
|
||||
features.append(InputFeatures(example_id=example.swag_id, choices_features=choices_features, label=label))
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
|
||||
@@ -237,18 +215,14 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
else:
|
||||
tokens_b.pop()
|
||||
|
||||
|
||||
def accuracy(out, labels):
|
||||
outputs = np.argmax(out, axis=1)
|
||||
return np.sum(outputs == labels)
|
||||
|
||||
|
||||
def select_field(features, field):
|
||||
return [
|
||||
[
|
||||
choice[field]
|
||||
for choice in feature.choices_features
|
||||
]
|
||||
for feature in features
|
||||
]
|
||||
return [[choice[field] for choice in feature.choices_features] for feature in features]
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
@@ -258,24 +232,28 @@ def set_seed(args):
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
input_file = args.predict_file if evaluate else args.train_file
|
||||
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
|
||||
'dev' if evaluate else 'train',
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length)))
|
||||
cached_features_file = os.path.join(
|
||||
os.path.dirname(input_file),
|
||||
"cached_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", input_file)
|
||||
examples = read_swag_examples(input_file)
|
||||
features = convert_examples_to_features(
|
||||
examples, tokenizer, args.max_seq_length, not evaluate)
|
||||
features = convert_examples_to_features(examples, tokenizer, args.max_seq_length, not evaluate)
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
@@ -285,21 +263,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long)
|
||||
all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long)
|
||||
all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long)
|
||||
all_input_ids = torch.tensor(select_field(features, "input_ids"), dtype=torch.long)
|
||||
all_input_mask = torch.tensor(select_field(features, "input_mask"), dtype=torch.long)
|
||||
all_segment_ids = torch.tensor(select_field(features, "segment_ids"), dtype=torch.long)
|
||||
all_label = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
|
||||
if evaluate:
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_label)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
|
||||
else:
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_label)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
|
||||
|
||||
if output_examples:
|
||||
return dataset, examples, features
|
||||
return dataset
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
@@ -316,13 +294,18 @@ def train(args, train_dataset, model, tokenizer):
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
@@ -336,17 +319,21 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
@@ -360,11 +347,13 @@ def train(args, train_dataset, model, tokenizer):
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
#'token_type_ids': None if args.model_type == 'xlm' else batch[2],
|
||||
'token_type_ids': batch[2],
|
||||
'labels': batch[3]}
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
#'token_type_ids': None if args.model_type == 'xlm' else batch[2],
|
||||
"token_type_ids": batch[2],
|
||||
"labels": batch[3],
|
||||
}
|
||||
# if args.model_type in ['xlnet', 'xlm']:
|
||||
# inputs.update({'cls_index': batch[5],
|
||||
# 'p_mask': batch[6]})
|
||||
@@ -372,7 +361,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
@@ -393,23 +382,27 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_vocabulary(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
@@ -424,6 +417,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
|
||||
|
||||
@@ -440,7 +434,6 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
logger.info(" Num examples = %d", len(dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
|
||||
|
||||
eval_loss, eval_accuracy = 0, 0
|
||||
nb_eval_steps, nb_eval_examples = 0, 0
|
||||
|
||||
@@ -448,11 +441,13 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
# 'token_type_ids': None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
||||
'token_type_ids': batch[2],
|
||||
'labels': batch[3]}
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
# 'token_type_ids': None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
||||
"token_type_ids": batch[2],
|
||||
"labels": batch[3],
|
||||
}
|
||||
|
||||
# if args.model_type in ['xlnet', 'xlm']:
|
||||
# inputs.update({'cls_index': batch[4],
|
||||
@@ -462,17 +457,16 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
|
||||
logits = logits.detach().cpu().numpy()
|
||||
label_ids = inputs['labels'].to('cpu').numpy()
|
||||
label_ids = inputs["labels"].to("cpu").numpy()
|
||||
tmp_eval_accuracy = accuracy(logits, label_ids)
|
||||
eval_accuracy += tmp_eval_accuracy
|
||||
|
||||
nb_eval_steps += 1
|
||||
nb_eval_examples += inputs['input_ids'].size(0)
|
||||
nb_eval_examples += inputs["input_ids"].size(0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
eval_accuracy = eval_accuracy / nb_eval_examples
|
||||
result = {'eval_loss': eval_loss,
|
||||
'eval_accuracy': eval_accuracy}
|
||||
result = {"eval_loss": eval_loss, "eval_accuracy": eval_accuracy}
|
||||
|
||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
@@ -483,92 +477,144 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--train_file", default=None, type=str, required=True,
|
||||
help="SWAG csv for training. E.g., train.csv")
|
||||
parser.add_argument("--predict_file", default=None, type=str, required=True,
|
||||
help="SWAG csv for predictions. E.g., val.csv or test.csv")
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.")
|
||||
parser.add_argument(
|
||||
"--train_file", default=None, type=str, required=True, help="SWAG csv for training. E.g., train.csv"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predict_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="SWAG csv for predictions. E.g., val.csv or test.csv",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--max_seq_length", default=384, type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||
help="Rul evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=384,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Whether not to use CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="local_rank for distributed training on gpus")
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
@@ -580,16 +626,24 @@ def main():
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
@@ -601,8 +655,12 @@ def main():
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config
|
||||
)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
@@ -617,7 +675,6 @@ def main():
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
# Save the trained model and the tokenizer
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
# Create output directory if needed
|
||||
@@ -627,19 +684,20 @@ def main():
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
@@ -650,14 +708,16 @@ def main():
|
||||
checkpoints = [args.model_name_or_path]
|
||||
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
tokenizer = tokenizer_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
@@ -665,7 +725,7 @@ def main():
|
||||
# Evaluate
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
|
||||
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
|
||||
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
logger.info("Results: {}".format(results))
|
||||
|
||||
@@ -30,44 +30,36 @@ import torch
|
||||
|
||||
from transformers import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
|
||||
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
|
||||
help='pretrained model name')
|
||||
parser.add_argument('--split', type=str, default='test',
|
||||
choices=['all', 'valid', 'test'],
|
||||
help='which split to evaluate')
|
||||
parser.add_argument('--batch_size', type=int, default=10,
|
||||
help='batch size')
|
||||
parser.add_argument('--tgt_len', type=int, default=128,
|
||||
help='number of tokens to predict')
|
||||
parser.add_argument('--ext_len', type=int, default=0,
|
||||
help='length of the extended context')
|
||||
parser.add_argument('--mem_len', type=int, default=1600,
|
||||
help='length of the retained previous heads')
|
||||
parser.add_argument('--clamp_len', type=int, default=1000,
|
||||
help='max positional embedding index')
|
||||
parser.add_argument('--no_cuda', action='store_true',
|
||||
help='Do not use CUDA even though CUA is available')
|
||||
parser.add_argument('--work_dir', type=str, required=True,
|
||||
help='path to the work_dir')
|
||||
parser.add_argument('--no_log', action='store_true',
|
||||
help='do not log the eval result')
|
||||
parser.add_argument('--same_length', action='store_true',
|
||||
help='set same length attention with masking')
|
||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser = argparse.ArgumentParser(description="PyTorch Transformer Language Model")
|
||||
parser.add_argument("--model_name", type=str, default="transfo-xl-wt103", help="pretrained model name")
|
||||
parser.add_argument(
|
||||
"--split", type=str, default="test", choices=["all", "valid", "test"], help="which split to evaluate"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=10, help="batch size")
|
||||
parser.add_argument("--tgt_len", type=int, default=128, help="number of tokens to predict")
|
||||
parser.add_argument("--ext_len", type=int, default=0, help="length of the extended context")
|
||||
parser.add_argument("--mem_len", type=int, default=1600, help="length of the retained previous heads")
|
||||
parser.add_argument("--clamp_len", type=int, default=1000, help="max positional embedding index")
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Do not use CUDA even though CUA is available")
|
||||
parser.add_argument("--work_dir", type=str, required=True, help="path to the work_dir")
|
||||
parser.add_argument("--no_log", action="store_true", help="do not log the eval result")
|
||||
parser.add_argument("--same_length", action="store_true", help="set same length attention with masking")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
assert args.ext_len >= 0, 'extended context length must be non-negative'
|
||||
assert args.ext_len >= 0, "extended context length must be non-negative"
|
||||
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
@@ -84,17 +76,18 @@ def main():
|
||||
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
|
||||
ntokens = len(corpus.vocab)
|
||||
|
||||
va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
|
||||
device=device, ext_len=args.ext_len)
|
||||
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
|
||||
device=device, ext_len=args.ext_len)
|
||||
va_iter = corpus.get_iterator("valid", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
|
||||
te_iter = corpus.get_iterator("test", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
|
||||
|
||||
# Load a pre-trained model
|
||||
model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
|
||||
model = model.to(device)
|
||||
|
||||
logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
|
||||
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
|
||||
logger.info(
|
||||
"Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".format(
|
||||
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len
|
||||
)
|
||||
)
|
||||
|
||||
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
|
||||
if args.clamp_len > 0:
|
||||
@@ -108,7 +101,7 @@ def main():
|
||||
def evaluate(eval_iter):
|
||||
# Turn on evaluation mode which disables dropout.
|
||||
model.eval()
|
||||
total_len, total_loss = 0, 0.
|
||||
total_len, total_loss = 0, 0.0
|
||||
start_time = time.time()
|
||||
with torch.no_grad():
|
||||
mems = None
|
||||
@@ -119,35 +112,34 @@ def main():
|
||||
total_loss += seq_len * loss.item()
|
||||
total_len += seq_len
|
||||
total_time = time.time() - start_time
|
||||
logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
|
||||
total_time, 1000 * total_time / (idx+1)))
|
||||
logger.info("Time : {:.2f}s, {:.2f}ms/segment".format(total_time, 1000 * total_time / (idx + 1)))
|
||||
return total_loss / total_len
|
||||
|
||||
# Run on test data.
|
||||
if args.split == 'all':
|
||||
if args.split == "all":
|
||||
test_loss = evaluate(te_iter)
|
||||
valid_loss = evaluate(va_iter)
|
||||
elif args.split == 'valid':
|
||||
elif args.split == "valid":
|
||||
valid_loss = evaluate(va_iter)
|
||||
test_loss = None
|
||||
elif args.split == 'test':
|
||||
elif args.split == "test":
|
||||
test_loss = evaluate(te_iter)
|
||||
valid_loss = None
|
||||
|
||||
def format_log(loss, split):
|
||||
log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
|
||||
split, loss, math.exp(loss))
|
||||
log_str = "| {0} loss {1:5.2f} | {0} ppl {2:9.3f} ".format(split, loss, math.exp(loss))
|
||||
return log_str
|
||||
|
||||
log_str = ''
|
||||
log_str = ""
|
||||
if valid_loss is not None:
|
||||
log_str += format_log(valid_loss, 'valid')
|
||||
log_str += format_log(valid_loss, "valid")
|
||||
if test_loss is not None:
|
||||
log_str += format_log(test_loss, 'test')
|
||||
log_str += format_log(test_loss, "test")
|
||||
|
||||
logger.info('=' * 100)
|
||||
logger.info("=" * 100)
|
||||
logger.info(log_str)
|
||||
logger.info('=' * 100)
|
||||
logger.info("=" * 100)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -40,14 +40,12 @@ from utils import logger
|
||||
from lm_seqs_dataset import LmSeqsDataset
|
||||
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
|
||||
|
||||
|
||||
class Distiller:
|
||||
def __init__(self,
|
||||
params: dict,
|
||||
dataset: LmSeqsDataset,
|
||||
token_probs: torch.tensor,
|
||||
student: nn.Module,
|
||||
teacher: nn.Module):
|
||||
logger.info('Initializing Distiller')
|
||||
def __init__(
|
||||
self, params: dict, dataset: LmSeqsDataset, token_probs: torch.tensor, student: nn.Module, teacher: nn.Module
|
||||
):
|
||||
logger.info("Initializing Distiller")
|
||||
self.params = params
|
||||
self.dump_path = params.dump_path
|
||||
self.multi_gpu = params.multi_gpu
|
||||
@@ -70,12 +68,10 @@ class Distiller:
|
||||
else:
|
||||
sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
|
||||
|
||||
self.dataloader = DataLoader(dataset=dataset,
|
||||
batch_sampler=sampler,
|
||||
collate_fn=dataset.batch_sequences)
|
||||
self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)
|
||||
|
||||
self.temperature = params.temperature
|
||||
assert self.temperature > 0.
|
||||
assert self.temperature > 0.0
|
||||
|
||||
self.alpha_ce = params.alpha_ce
|
||||
self.alpha_mlm = params.alpha_mlm
|
||||
@@ -85,18 +81,18 @@ class Distiller:
|
||||
|
||||
self.mlm = params.mlm
|
||||
if self.mlm:
|
||||
logger.info(f'Using MLM loss for LM step.')
|
||||
logger.info(f"Using MLM loss for LM step.")
|
||||
self.mlm_mask_prop = params.mlm_mask_prop
|
||||
assert 0.0 <= self.mlm_mask_prop <= 1.0
|
||||
assert params.word_mask + params.word_keep + params.word_rand == 1.0
|
||||
self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
|
||||
self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
|
||||
self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
|
||||
self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
|
||||
self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
|
||||
if self.fp16:
|
||||
self.pred_probs = self.pred_probs.half()
|
||||
self.token_probs = self.token_probs.half()
|
||||
else:
|
||||
logger.info(f'Using CLM loss for LM step.')
|
||||
logger.info(f"Using CLM loss for LM step.")
|
||||
|
||||
self.epoch = 0
|
||||
self.n_iter = 0
|
||||
@@ -107,38 +103,54 @@ class Distiller:
|
||||
self.last_loss_ce = 0
|
||||
self.last_loss_mlm = 0
|
||||
self.last_loss_clm = 0
|
||||
if self.alpha_mse > 0.: self.last_loss_mse = 0
|
||||
if self.alpha_cos > 0.: self.last_loss_cos = 0
|
||||
if self.alpha_mse > 0.0:
|
||||
self.last_loss_mse = 0
|
||||
if self.alpha_cos > 0.0:
|
||||
self.last_loss_cos = 0
|
||||
self.last_log = 0
|
||||
|
||||
self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
|
||||
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
||||
if self.alpha_mse > 0.:
|
||||
self.mse_loss_fct = nn.MSELoss(reduction='sum')
|
||||
if self.alpha_cos > 0.:
|
||||
self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')
|
||||
if self.alpha_mse > 0.0:
|
||||
self.mse_loss_fct = nn.MSELoss(reduction="sum")
|
||||
if self.alpha_cos > 0.0:
|
||||
self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")
|
||||
|
||||
logger.info('--- Initializing model optimizer')
|
||||
logger.info("--- Initializing model optimizer")
|
||||
assert params.gradient_accumulation_steps >= 1
|
||||
self.num_steps_epoch = len(self.dataloader)
|
||||
num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
|
||||
num_train_optimization_steps = (
|
||||
int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
|
||||
)
|
||||
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': params.weight_decay},
|
||||
{'params': [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': 0.0}
|
||||
{
|
||||
"params": [
|
||||
p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
|
||||
],
|
||||
"weight_decay": params.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
logger.info("------ Number of trainable parameters (student): %i" % sum([p.numel() for p in self.student.parameters() if p.requires_grad]))
|
||||
logger.info(
|
||||
"------ Number of trainable parameters (student): %i"
|
||||
% sum([p.numel() for p in self.student.parameters() if p.requires_grad])
|
||||
)
|
||||
logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
|
||||
self.optimizer = AdamW(optimizer_grouped_parameters,
|
||||
lr=params.learning_rate,
|
||||
eps=params.adam_epsilon,
|
||||
betas=(0.9, 0.98))
|
||||
self.optimizer = AdamW(
|
||||
optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
|
||||
)
|
||||
|
||||
warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
|
||||
self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
|
||||
num_warmup_steps=warmup_steps,
|
||||
num_training_steps=num_train_optimization_steps)
|
||||
self.scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
|
||||
)
|
||||
|
||||
if self.fp16:
|
||||
try:
|
||||
@@ -146,33 +158,36 @@ class Distiller:
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
|
||||
self.student, self.optimizer = amp.initialize(self.student,
|
||||
self.optimizer,
|
||||
opt_level=self.params.fp16_opt_level)
|
||||
self.student, self.optimizer = amp.initialize(
|
||||
self.student, self.optimizer, opt_level=self.params.fp16_opt_level
|
||||
)
|
||||
self.teacher = self.teacher.half()
|
||||
|
||||
if self.multi_gpu:
|
||||
if self.fp16:
|
||||
from apex.parallel import DistributedDataParallel
|
||||
|
||||
logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
|
||||
self.student = DistributedDataParallel(self.student)
|
||||
else:
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
|
||||
logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
|
||||
self.student = DistributedDataParallel(self.student,
|
||||
device_ids=[params.local_rank],
|
||||
output_device=params.local_rank,
|
||||
find_unused_parameters=True)
|
||||
self.student = DistributedDataParallel(
|
||||
self.student,
|
||||
device_ids=[params.local_rank],
|
||||
output_device=params.local_rank,
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
|
||||
self.is_master = params.is_master
|
||||
if self.is_master:
|
||||
logger.info('--- Initializing Tensorboard')
|
||||
self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
|
||||
self.tensorboard.add_text(tag='config/training', text_string=str(self.params), global_step=0)
|
||||
self.tensorboard.add_text(tag='config/student', text_string=str(self.student_config), global_step=0)
|
||||
logger.info("--- Initializing Tensorboard")
|
||||
self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
|
||||
self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
|
||||
self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)
|
||||
|
||||
def prepare_batch_mlm(self,
|
||||
batch):
|
||||
def prepare_batch_mlm(self, batch):
|
||||
"""
|
||||
Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the masked label for MLM.
|
||||
|
||||
@@ -192,7 +207,7 @@ class Distiller:
|
||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||
assert token_ids.size(0) == lengths.size(0)
|
||||
|
||||
attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
|
||||
attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
|
||||
|
||||
bs, max_seq_len = token_ids.size()
|
||||
mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
|
||||
@@ -200,11 +215,13 @@ class Distiller:
|
||||
x_prob = self.token_probs[token_ids.flatten()]
|
||||
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
|
||||
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
|
||||
pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool, device=token_ids.device) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
|
||||
pred_mask = torch.zeros(
|
||||
bs * max_seq_len, dtype=torch.bool, device=token_ids.device
|
||||
) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
|
||||
pred_mask[tgt_ids] = 1
|
||||
pred_mask = pred_mask.view(bs, max_seq_len)
|
||||
|
||||
pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0
|
||||
pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0
|
||||
|
||||
# mask a number of words == 0 [8] (faster with fp16)
|
||||
if self.fp16:
|
||||
@@ -213,26 +230,29 @@ class Distiller:
|
||||
pred_mask = pred_mask.view(-1)
|
||||
n2 = max(n1 % 8, 8 * (n1 // 8))
|
||||
if n2 != n1:
|
||||
pred_mask[torch.nonzero(pred_mask).view(-1)[:n1-n2]] = 0
|
||||
pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = 0
|
||||
pred_mask = pred_mask.view(bs, max_seq_len)
|
||||
assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()
|
||||
|
||||
_token_ids_real = token_ids[pred_mask]
|
||||
_token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
|
||||
_token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
|
||||
_token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
|
||||
probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
|
||||
_token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
|
||||
_token_ids = (
|
||||
_token_ids_mask * (probs == 0).long()
|
||||
+ _token_ids_real * (probs == 1).long()
|
||||
+ _token_ids_rand * (probs == 2).long()
|
||||
)
|
||||
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
|
||||
|
||||
mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||
mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||
|
||||
# sanity checks
|
||||
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
|
||||
|
||||
return token_ids, attn_mask, mlm_labels
|
||||
|
||||
def prepare_batch_clm(self,
|
||||
batch):
|
||||
def prepare_batch_clm(self, batch):
|
||||
"""
|
||||
Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the labels for CLM.
|
||||
|
||||
@@ -252,18 +272,16 @@ class Distiller:
|
||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||
assert token_ids.size(0) == lengths.size(0)
|
||||
|
||||
attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
|
||||
attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
|
||||
clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
|
||||
clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||
clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||
|
||||
# sanity checks
|
||||
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
|
||||
|
||||
return token_ids, attn_mask, clm_labels
|
||||
|
||||
def round_batch(self,
|
||||
x: torch.tensor,
|
||||
lengths: torch.tensor):
|
||||
def round_batch(self, x: torch.tensor, lengths: torch.tensor):
|
||||
"""
|
||||
For float16 only.
|
||||
Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
|
||||
@@ -299,9 +317,9 @@ class Distiller:
|
||||
pad = 8 - (ml1 % 8)
|
||||
ml2 = ml1 + pad
|
||||
if self.mlm:
|
||||
pad_id = self.params.special_tok_ids['pad_token']
|
||||
pad_id = self.params.special_tok_ids["pad_token"]
|
||||
else:
|
||||
pad_id = self.params.special_tok_ids['unk_token']
|
||||
pad_id = self.params.special_tok_ids["unk_token"]
|
||||
padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
|
||||
x = torch.cat([x, padding_tensor], 1)
|
||||
assert x.size() == (bs2, ml2)
|
||||
@@ -314,20 +332,22 @@ class Distiller:
|
||||
"""
|
||||
The real training loop.
|
||||
"""
|
||||
if self.is_master: logger.info('Starting training')
|
||||
if self.is_master:
|
||||
logger.info("Starting training")
|
||||
self.last_log = time.time()
|
||||
self.student.train()
|
||||
self.teacher.eval()
|
||||
|
||||
for _ in range(self.params.n_epoch):
|
||||
if self.is_master: logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')
|
||||
if self.is_master:
|
||||
logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
|
||||
if self.multi_gpu:
|
||||
torch.distributed.barrier()
|
||||
|
||||
iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
|
||||
for batch in iter_bar:
|
||||
if self.params.n_gpu > 0:
|
||||
batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
|
||||
batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch)
|
||||
|
||||
if self.mlm:
|
||||
token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
|
||||
@@ -336,22 +356,21 @@ class Distiller:
|
||||
self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)
|
||||
|
||||
iter_bar.update()
|
||||
iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
|
||||
'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}'})
|
||||
iter_bar.set_postfix(
|
||||
{"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"}
|
||||
)
|
||||
iter_bar.close()
|
||||
|
||||
if self.is_master: logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
|
||||
if self.is_master:
|
||||
logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
|
||||
self.end_epoch()
|
||||
|
||||
if self.is_master:
|
||||
logger.info(f'Save very last checkpoint as `pytorch_model.bin`.')
|
||||
self.save_checkpoint(checkpoint_name=f'pytorch_model.bin')
|
||||
logger.info('Training is finished')
|
||||
logger.info(f"Save very last checkpoint as `pytorch_model.bin`.")
|
||||
self.save_checkpoint(checkpoint_name=f"pytorch_model.bin")
|
||||
logger.info("Training is finished")
|
||||
|
||||
def step(self,
|
||||
input_ids: torch.tensor,
|
||||
attention_mask: torch.tensor,
|
||||
lm_labels: torch.tensor):
|
||||
def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor):
|
||||
"""
|
||||
One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
|
||||
and possibly a parameter update (depending on the gradient accumulation).
|
||||
@@ -363,78 +382,91 @@ class Distiller:
|
||||
lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
|
||||
"""
|
||||
if self.mlm:
|
||||
s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
|
||||
s_logits, s_hidden_states = self.student(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
) # (bs, seq_length, voc_size)
|
||||
with torch.no_grad():
|
||||
t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
|
||||
t_logits, t_hidden_states = self.teacher(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
) # (bs, seq_length, voc_size)
|
||||
else:
|
||||
s_logits, _, s_hidden_states = self.student(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size)
|
||||
s_logits, _, s_hidden_states = self.student(
|
||||
input_ids=input_ids, attention_mask=None
|
||||
) # (bs, seq_length, voc_size)
|
||||
with torch.no_grad():
|
||||
t_logits, _, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size)
|
||||
t_logits, _, t_hidden_states = self.teacher(
|
||||
input_ids=input_ids, attention_mask=None
|
||||
) # (bs, seq_length, voc_size)
|
||||
assert s_logits.size() == t_logits.size()
|
||||
|
||||
#https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
|
||||
#https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
|
||||
# https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
|
||||
# https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
|
||||
if self.params.restrict_ce_to_mask:
|
||||
mask = (lm_labels>-1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
|
||||
mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
|
||||
else:
|
||||
mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
|
||||
s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
|
||||
s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
assert t_logits_slct.size() == s_logits_slct.size()
|
||||
|
||||
loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
|
||||
F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
|
||||
loss = self.alpha_ce*loss_ce
|
||||
loss_ce = (
|
||||
self.ce_loss_fct(
|
||||
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
||||
F.softmax(t_logits_slct / self.temperature, dim=-1),
|
||||
)
|
||||
* (self.temperature) ** 2
|
||||
)
|
||||
loss = self.alpha_ce * loss_ce
|
||||
|
||||
if self.alpha_mlm > 0.:
|
||||
if self.alpha_mlm > 0.0:
|
||||
loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
|
||||
loss += self.alpha_mlm * loss_mlm
|
||||
if self.alpha_clm > 0.:
|
||||
if self.alpha_clm > 0.0:
|
||||
shift_logits = s_logits[..., :-1, :].contiguous()
|
||||
shift_labels = lm_labels[..., 1:].contiguous()
|
||||
loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1))
|
||||
loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
loss += self.alpha_clm * loss_clm
|
||||
|
||||
if self.alpha_mse > 0.:
|
||||
loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction
|
||||
if self.alpha_mse > 0.0:
|
||||
loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(
|
||||
0
|
||||
) # Reproducing batchmean reduction
|
||||
loss += self.alpha_mse * loss_mse
|
||||
if self.alpha_cos > 0.:
|
||||
s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim)
|
||||
t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim)
|
||||
mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim)
|
||||
if self.alpha_cos > 0.0:
|
||||
s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim)
|
||||
t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim)
|
||||
mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim)
|
||||
assert s_hidden_states.size() == t_hidden_states.size()
|
||||
dim = s_hidden_states.size(-1)
|
||||
|
||||
s_hidden_states_slct = torch.masked_select(s_hidden_states, mask) # (bs * seq_length * dim)
|
||||
s_hidden_states_slct = s_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
|
||||
t_hidden_states_slct = torch.masked_select(t_hidden_states, mask) # (bs * seq_length * dim)
|
||||
t_hidden_states_slct = t_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
|
||||
s_hidden_states_slct = torch.masked_select(s_hidden_states, mask) # (bs * seq_length * dim)
|
||||
s_hidden_states_slct = s_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
|
||||
t_hidden_states_slct = torch.masked_select(t_hidden_states, mask) # (bs * seq_length * dim)
|
||||
t_hidden_states_slct = t_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
|
||||
|
||||
target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,)
|
||||
target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,)
|
||||
loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
|
||||
loss += self.alpha_cos * loss_cos
|
||||
|
||||
self.total_loss_epoch += loss.item()
|
||||
self.last_loss = loss.item()
|
||||
self.last_loss_ce = loss_ce.item()
|
||||
if self.alpha_mlm > 0.:
|
||||
if self.alpha_mlm > 0.0:
|
||||
self.last_loss_mlm = loss_mlm.item()
|
||||
if self.alpha_clm > 0.:
|
||||
if self.alpha_clm > 0.0:
|
||||
self.last_loss_clm = loss_clm.item()
|
||||
if self.alpha_mse > 0.:
|
||||
if self.alpha_mse > 0.0:
|
||||
self.last_loss_mse = loss_mse.item()
|
||||
if self.alpha_cos > 0.:
|
||||
if self.alpha_cos > 0.0:
|
||||
self.last_loss_cos = loss_cos.item()
|
||||
|
||||
self.optimize(loss)
|
||||
|
||||
self.n_sequences_epoch += input_ids.size(0)
|
||||
|
||||
def optimize(self,
|
||||
loss):
|
||||
def optimize(self, loss):
|
||||
"""
|
||||
Normalization on the loss (gradient accumulation or distributed training), followed by
|
||||
backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
|
||||
@@ -442,7 +474,7 @@ class Distiller:
|
||||
"""
|
||||
# Check for NaN
|
||||
if (loss != loss).data.any():
|
||||
logger.error('NaN detected')
|
||||
logger.error("NaN detected")
|
||||
exit()
|
||||
|
||||
if self.multi_gpu:
|
||||
@@ -452,6 +484,7 @@ class Distiller:
|
||||
|
||||
if self.fp16:
|
||||
from apex import amp
|
||||
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
@@ -488,53 +521,84 @@ class Distiller:
|
||||
return
|
||||
|
||||
for param_name, param in self.student.named_parameters():
|
||||
self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter)
|
||||
self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter
|
||||
)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter
|
||||
)
|
||||
if param.grad is None:
|
||||
continue
|
||||
self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(),global_step=self.n_total_iter)
|
||||
self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter
|
||||
)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter
|
||||
)
|
||||
|
||||
self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.n_total_iter)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/cum_avg_loss_epoch",
|
||||
scalar_value=self.total_loss_epoch / self.n_iter,
|
||||
global_step=self.n_total_iter,
|
||||
)
|
||||
self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
|
||||
self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
|
||||
if self.alpha_mlm > 0.:
|
||||
self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
|
||||
if self.alpha_clm > 0.:
|
||||
self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter)
|
||||
if self.alpha_mse > 0.:
|
||||
self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
|
||||
if self.alpha_cos > 0.:
|
||||
self.tensorboard.add_scalar(tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter)
|
||||
self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter
|
||||
)
|
||||
if self.alpha_mlm > 0.0:
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter
|
||||
)
|
||||
if self.alpha_clm > 0.0:
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter
|
||||
)
|
||||
if self.alpha_mse > 0.0:
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter
|
||||
)
|
||||
if self.alpha_cos > 0.0:
|
||||
self.tensorboard.add_scalar(
|
||||
tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter
|
||||
)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter
|
||||
)
|
||||
|
||||
self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)
|
||||
self.tensorboard.add_scalar(tag="global/speed", scalar_value=time.time()-self.last_log, global_step=self.n_total_iter)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="global/memory_usage",
|
||||
scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
|
||||
global_step=self.n_total_iter,
|
||||
)
|
||||
self.tensorboard.add_scalar(
|
||||
tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter
|
||||
)
|
||||
|
||||
def end_epoch(self):
|
||||
"""
|
||||
Finally arrived at the end of epoch (full pass on dataset).
|
||||
Do some tensorboard logging and checkpoint saving.
|
||||
"""
|
||||
logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.')
|
||||
logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")
|
||||
|
||||
if self.is_master:
|
||||
self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth')
|
||||
self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.epoch)
|
||||
self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
|
||||
self.tensorboard.add_scalar(
|
||||
tag="epoch/loss", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch
|
||||
)
|
||||
|
||||
self.epoch += 1
|
||||
self.n_sequences_epoch = 0
|
||||
self.n_iter = 0
|
||||
self.total_loss_epoch = 0
|
||||
|
||||
def save_checkpoint(self,
|
||||
checkpoint_name: str = 'checkpoint.pth'):
|
||||
def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
|
||||
"""
|
||||
Save the current state. Only by the master process.
|
||||
"""
|
||||
if not self.is_master:
|
||||
return
|
||||
mdl_to_save = self.student.module if hasattr(self.student, 'module') else self.student
|
||||
mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
|
||||
mdl_to_save.config.save_pretrained(self.dump_path)
|
||||
state_dict = mdl_to_save.state_dict()
|
||||
torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
|
||||
|
||||
@@ -23,12 +23,14 @@ from torch.utils.data.sampler import BatchSampler, Sampler
|
||||
|
||||
from utils import logger
|
||||
|
||||
|
||||
def _quantize(x, bins):
|
||||
bins = copy.deepcopy(bins)
|
||||
bins = sorted(bins)
|
||||
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
|
||||
return quantized
|
||||
|
||||
|
||||
def create_lengths_groups(lengths, k=0):
|
||||
bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10]
|
||||
groups = _quantize(lengths, bins)
|
||||
@@ -39,6 +41,7 @@ def create_lengths_groups(lengths, k=0):
|
||||
logger.info("Count of instances per bin: {}".format(counts))
|
||||
return groups
|
||||
|
||||
|
||||
class GroupedBatchSampler(BatchSampler):
|
||||
"""
|
||||
Wraps another sampler to yield a mini-batch of indices.
|
||||
@@ -53,11 +56,11 @@ class GroupedBatchSampler(BatchSampler):
|
||||
0, i.e. they must be in the range [0, num_groups).
|
||||
batch_size (int): Size of mini-batch.
|
||||
"""
|
||||
|
||||
def __init__(self, sampler, group_ids, batch_size):
|
||||
if not isinstance(sampler, Sampler):
|
||||
raise ValueError(
|
||||
"sampler should be an instance of "
|
||||
"torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
||||
"sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
||||
)
|
||||
self.sampler = sampler
|
||||
self.group_ids = group_ids
|
||||
@@ -73,7 +76,7 @@ class GroupedBatchSampler(BatchSampler):
|
||||
buffer_per_group[group_id].append(idx)
|
||||
samples_per_group[group_id].append(idx)
|
||||
if len(buffer_per_group[group_id]) == self.batch_size:
|
||||
yield buffer_per_group[group_id] #TODO
|
||||
yield buffer_per_group[group_id] # TODO
|
||||
num_batches += 1
|
||||
del buffer_per_group[group_id]
|
||||
assert len(buffer_per_group[group_id]) < self.batch_size
|
||||
@@ -90,8 +93,8 @@ class GroupedBatchSampler(BatchSampler):
|
||||
for group_id, idxs in sorted(buffer_per_group.items(), key=lambda x: x[0]):
|
||||
batch_idx.extend(idxs)
|
||||
if len(batch_idx) >= self.batch_size:
|
||||
yield batch_idx[:self.batch_size]
|
||||
batch_idx = batch_idx[self.batch_size:]
|
||||
yield batch_idx[: self.batch_size]
|
||||
batch_idx = batch_idx[self.batch_size :]
|
||||
num_remaining -= 1
|
||||
if len(batch_idx) > 0:
|
||||
yield batch_idx
|
||||
|
||||
@@ -21,6 +21,7 @@ from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
from utils import logger
|
||||
|
||||
|
||||
class LmSeqsDataset(Dataset):
|
||||
"""Custom Dataset wrapping language modeling sequences.
|
||||
|
||||
@@ -32,9 +33,7 @@ class LmSeqsDataset(Dataset):
|
||||
data: `List[np.array[int]]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
params,
|
||||
data):
|
||||
def __init__(self, params, data):
|
||||
self.params = params
|
||||
|
||||
self.token_ids = np.array(data)
|
||||
@@ -65,17 +64,17 @@ class LmSeqsDataset(Dataset):
|
||||
"""
|
||||
max_len = self.params.max_model_input_size
|
||||
indices = self.lengths > max_len
|
||||
logger.info(f'Splitting {sum(indices)} too long sequences.')
|
||||
logger.info(f"Splitting {sum(indices)} too long sequences.")
|
||||
|
||||
def divide_chunks(l, n):
|
||||
return [l[i:i + n] for i in range(0, len(l), n)]
|
||||
return [l[i : i + n] for i in range(0, len(l), n)]
|
||||
|
||||
new_tok_ids = []
|
||||
new_lengths = []
|
||||
if self.params.mlm:
|
||||
cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token']
|
||||
cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
|
||||
else:
|
||||
cls_id, sep_id = self.params.special_tok_ids['bos_token'], self.params.special_tok_ids['eos_token']
|
||||
cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]
|
||||
|
||||
for seq_, len_ in zip(self.token_ids, self.lengths):
|
||||
assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
|
||||
@@ -84,7 +83,7 @@ class LmSeqsDataset(Dataset):
|
||||
new_lengths.append(len_)
|
||||
else:
|
||||
sub_seqs = []
|
||||
for sub_s in divide_chunks(seq_, max_len-2):
|
||||
for sub_s in divide_chunks(seq_, max_len - 2):
|
||||
if sub_s[0] != cls_id:
|
||||
sub_s = np.insert(sub_s, 0, cls_id)
|
||||
if sub_s[-1] != sep_id:
|
||||
@@ -108,7 +107,7 @@ class LmSeqsDataset(Dataset):
|
||||
self.token_ids = self.token_ids[indices]
|
||||
self.lengths = self.lengths[indices]
|
||||
new_size = len(self)
|
||||
logger.info(f'Remove {init_size - new_size} too short (<=11 tokens) sequences.')
|
||||
logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")
|
||||
|
||||
def print_statistics(self):
|
||||
"""
|
||||
@@ -116,7 +115,7 @@ class LmSeqsDataset(Dataset):
|
||||
"""
|
||||
if not self.params.is_master:
|
||||
return
|
||||
logger.info(f'{len(self)} sequences')
|
||||
logger.info(f"{len(self)} sequences")
|
||||
# data_len = sum(self.lengths)
|
||||
# nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
|
||||
# logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
|
||||
@@ -125,8 +124,7 @@ class LmSeqsDataset(Dataset):
|
||||
# nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids])
|
||||
# logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)')
|
||||
|
||||
def batch_sequences(self,
|
||||
batch):
|
||||
def batch_sequences(self, batch):
|
||||
"""
|
||||
Do the padding and transform into torch.tensor.
|
||||
"""
|
||||
@@ -139,13 +137,13 @@ class LmSeqsDataset(Dataset):
|
||||
|
||||
# Pad token ids
|
||||
if self.params.mlm:
|
||||
pad_idx = self.params.special_tok_ids['pad_token']
|
||||
pad_idx = self.params.special_tok_ids["pad_token"]
|
||||
else:
|
||||
pad_idx = self.params.special_tok_ids['unk_token']
|
||||
tk_ = [list(t.astype(int)) + [pad_idx]*(max_seq_len_-len(t)) for t in token_ids]
|
||||
pad_idx = self.params.special_tok_ids["unk_token"]
|
||||
tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
|
||||
assert len(tk_) == len(token_ids)
|
||||
assert all(len(t) == max_seq_len_ for t in tk_)
|
||||
|
||||
tk_t = torch.tensor(tk_) # (bs, max_seq_len_)
|
||||
tk_t = torch.tensor(tk_) # (bs, max_seq_len_)
|
||||
lg_t = torch.tensor(lengths) # (bs)
|
||||
return tk_t, lg_t
|
||||
|
||||
@@ -25,8 +25,7 @@ import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
@@ -38,19 +37,32 @@ except:
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForQuestionAnswering, BertTokenizer,
|
||||
XLMConfig, XLMForQuestionAnswering,
|
||||
XLMTokenizer, XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer,
|
||||
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
BertConfig,
|
||||
BertForQuestionAnswering,
|
||||
BertTokenizer,
|
||||
XLMConfig,
|
||||
XLMForQuestionAnswering,
|
||||
XLMTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForQuestionAnswering,
|
||||
DistilBertTokenizer,
|
||||
)
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
from ..utils_squad import (read_squad_examples, convert_examples_to_features,
|
||||
RawResult, write_predictions,
|
||||
RawResultExtended, write_predictions_extended)
|
||||
from ..utils_squad import (
|
||||
read_squad_examples,
|
||||
convert_examples_to_features,
|
||||
RawResult,
|
||||
write_predictions,
|
||||
RawResultExtended,
|
||||
write_predictions_extended,
|
||||
)
|
||||
|
||||
# The follwing import is the official SQuAD evaluation script (2.0).
|
||||
# You can remove it from the dependencies if you are using this script outside of the library
|
||||
@@ -59,16 +71,18 @@ from ..utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
||||
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
||||
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
@@ -76,9 +90,11 @@ def set_seed(args):
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def to_list(tensor):
|
||||
return tensor.detach().cpu().tolist()
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
@@ -95,13 +111,18 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
@@ -115,17 +136,21 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
@@ -141,40 +166,47 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
if teacher is not None:
|
||||
teacher.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'start_positions': batch[3],
|
||||
'end_positions': batch[4]}
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[5],
|
||||
'p_mask': batch[6]})
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"start_positions": batch[3],
|
||||
"end_positions": batch[4],
|
||||
}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
||||
outputs = model(**inputs)
|
||||
loss, start_logits_stu, end_logits_stu = outputs
|
||||
|
||||
# Distillation loss
|
||||
if teacher is not None:
|
||||
if 'token_type_ids' not in inputs:
|
||||
inputs['token_type_ids'] = None if args.teacher_type == 'xlm' else batch[2]
|
||||
if "token_type_ids" not in inputs:
|
||||
inputs["token_type_ids"] = None if args.teacher_type == "xlm" else batch[2]
|
||||
with torch.no_grad():
|
||||
start_logits_tea, end_logits_tea = teacher(input_ids=inputs['input_ids'],
|
||||
token_type_ids=inputs['token_type_ids'],
|
||||
attention_mask=inputs['attention_mask'])
|
||||
start_logits_tea, end_logits_tea = teacher(
|
||||
input_ids=inputs["input_ids"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
assert start_logits_tea.size() == start_logits_stu.size()
|
||||
assert end_logits_tea.size() == end_logits_stu.size()
|
||||
|
||||
loss_fct = nn.KLDivLoss(reduction='batchmean')
|
||||
loss_start = loss_fct(F.log_softmax(start_logits_stu/args.temperature, dim=-1),
|
||||
F.softmax(start_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
|
||||
loss_end = loss_fct(F.log_softmax(end_logits_stu/args.temperature, dim=-1),
|
||||
F.softmax(end_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
|
||||
loss_ce = (loss_start + loss_end)/2.
|
||||
loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
loss_start = loss_fct(
|
||||
F.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
F.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
) * (args.temperature ** 2)
|
||||
loss_end = loss_fct(
|
||||
F.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
F.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
) * (args.temperature ** 2)
|
||||
loss_ce = (loss_start + loss_end) / 2.0
|
||||
|
||||
loss = args.alpha_ce*loss_ce + args.alpha_squad*loss
|
||||
loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
@@ -195,22 +227,26 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
@@ -246,32 +282,31 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1]
|
||||
}
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2] # XLM don't use segment_ids
|
||||
example_indices = batch[3]
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[4],
|
||||
'p_mask': batch[5]})
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
|
||||
outputs = model(**inputs)
|
||||
|
||||
for i, example_index in enumerate(example_indices):
|
||||
eval_feature = features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
result = RawResultExtended(unique_id = unique_id,
|
||||
start_top_log_probs = to_list(outputs[0][i]),
|
||||
start_top_index = to_list(outputs[1][i]),
|
||||
end_top_log_probs = to_list(outputs[2][i]),
|
||||
end_top_index = to_list(outputs[3][i]),
|
||||
cls_logits = to_list(outputs[4][i]))
|
||||
result = RawResultExtended(
|
||||
unique_id=unique_id,
|
||||
start_top_log_probs=to_list(outputs[0][i]),
|
||||
start_top_index=to_list(outputs[1][i]),
|
||||
end_top_log_probs=to_list(outputs[2][i]),
|
||||
end_top_index=to_list(outputs[3][i]),
|
||||
cls_logits=to_list(outputs[4][i]),
|
||||
)
|
||||
else:
|
||||
result = RawResult(unique_id = unique_id,
|
||||
start_logits = to_list(outputs[0][i]),
|
||||
end_logits = to_list(outputs[1][i]))
|
||||
result = RawResult(
|
||||
unique_id=unique_id, start_logits=to_list(outputs[0][i]), end_logits=to_list(outputs[1][i])
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
# Compute predictions
|
||||
@@ -282,23 +317,44 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
else:
|
||||
output_null_log_odds_file = None
|
||||
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
||||
model.config.start_n_top, model.config.end_n_top,
|
||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
||||
write_predictions_extended(
|
||||
examples,
|
||||
features,
|
||||
all_results,
|
||||
args.n_best_size,
|
||||
args.max_answer_length,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
args.predict_file,
|
||||
model.config.start_n_top,
|
||||
model.config.end_n_top,
|
||||
args.version_2_with_negative,
|
||||
tokenizer,
|
||||
args.verbose_logging,
|
||||
)
|
||||
else:
|
||||
write_predictions(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||
write_predictions(
|
||||
examples,
|
||||
features,
|
||||
all_results,
|
||||
args.n_best_size,
|
||||
args.max_answer_length,
|
||||
args.do_lower_case,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
args.verbose_logging,
|
||||
args.version_2_with_negative,
|
||||
args.null_score_diff_threshold,
|
||||
)
|
||||
|
||||
# Evaluate with the official SQuAD script
|
||||
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
|
||||
pred_file=output_prediction_file,
|
||||
na_prob_file=output_null_log_odds_file)
|
||||
evaluate_options = EVAL_OPTS(
|
||||
data_file=args.predict_file, pred_file=output_prediction_file, na_prob_file=output_null_log_odds_file
|
||||
)
|
||||
results = evaluate_on_squad(evaluate_options)
|
||||
return results
|
||||
|
||||
@@ -309,24 +365,30 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
input_file = args.predict_file if evaluate else args.train_file
|
||||
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
|
||||
'dev' if evaluate else 'train',
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length)))
|
||||
cached_features_file = os.path.join(
|
||||
os.path.dirname(input_file),
|
||||
"cached_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", input_file)
|
||||
examples = read_squad_examples(input_file=input_file,
|
||||
is_training=not evaluate,
|
||||
version_2_with_negative=args.version_2_with_negative)
|
||||
features = convert_examples_to_features(examples=examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate)
|
||||
examples = read_squad_examples(
|
||||
input_file=input_file, is_training=not evaluate, version_2_with_negative=args.version_2_with_negative
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
examples=examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
@@ -342,14 +404,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||
if evaluate:
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_example_index, all_cls_index, all_p_mask)
|
||||
dataset = TensorDataset(
|
||||
all_input_ids, all_input_mask, all_segment_ids, all_example_index, all_cls_index, all_p_mask
|
||||
)
|
||||
else:
|
||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_start_positions, all_end_positions,
|
||||
all_cls_index, all_p_mask)
|
||||
dataset = TensorDataset(
|
||||
all_input_ids,
|
||||
all_input_mask,
|
||||
all_segment_ids,
|
||||
all_start_positions,
|
||||
all_end_positions,
|
||||
all_cls_index,
|
||||
all_p_mask,
|
||||
)
|
||||
|
||||
if output_examples:
|
||||
return dataset, examples, features
|
||||
@@ -360,121 +429,213 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--train_file", default=None, type=str, required=True,
|
||||
help="SQuAD json for training. E.g., train-v1.1.json")
|
||||
parser.add_argument("--predict_file", default=None, type=str, required=True,
|
||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.")
|
||||
parser.add_argument(
|
||||
"--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predict_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.",
|
||||
)
|
||||
|
||||
# Distillation parameters (optional)
|
||||
parser.add_argument('--teacher_type', default=None, type=str,
|
||||
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.")
|
||||
parser.add_argument('--teacher_name_or_path', default=None, type=str,
|
||||
help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.")
|
||||
parser.add_argument('--alpha_ce', default=0.5, type=float,
|
||||
help="Distillation loss linear weight. Only for distillation.")
|
||||
parser.add_argument('--alpha_squad', default=0.5, type=float,
|
||||
help="True SQuAD loss linear weight. Only for distillation.")
|
||||
parser.add_argument('--temperature', default=2.0, type=float,
|
||||
help="Distillation temperature. Only for distillation.")
|
||||
parser.add_argument(
|
||||
"--teacher_type",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--teacher_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_ce", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_squad", default=0.5, type=float, help="True SQuAD loss linear weight. Only for distillation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
|
||||
parser.add_argument('--version_2_with_negative', action='store_true',
|
||||
help='If true, the SQuAD examples contain some that do not have an answer.')
|
||||
parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
|
||||
help="If null_score - best_non_null is greater than the threshold predict null.")
|
||||
parser.add_argument(
|
||||
"--version_2_with_negative",
|
||||
action="store_true",
|
||||
help="If true, the SQuAD examples contain some that do not have an answer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--null_score_diff_threshold",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="If null_score - best_non_null is greater than the threshold predict null.",
|
||||
)
|
||||
|
||||
parser.add_argument("--max_seq_length", default=384, type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded.")
|
||||
parser.add_argument("--doc_stride", default=128, type=int,
|
||||
help="When splitting up a long document into chunks, how much stride to take between chunks.")
|
||||
parser.add_argument("--max_query_length", default=64, type=int,
|
||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||
help="Rul evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=384,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--doc_stride",
|
||||
default=128,
|
||||
type=int,
|
||||
help="When splitting up a long document into chunks, how much stride to take between chunks.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_query_length",
|
||||
default=64,
|
||||
type=int,
|
||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--n_best_size", default=20, type=int,
|
||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
|
||||
parser.add_argument("--max_answer_length", default=30, type=int,
|
||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another.")
|
||||
parser.add_argument("--verbose_logging", action='store_true',
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument(
|
||||
"--n_best_size",
|
||||
default=20,
|
||||
type=int,
|
||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_answer_length",
|
||||
default=30,
|
||||
type=int,
|
||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose_logging",
|
||||
action="store_true",
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.",
|
||||
)
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Whether not to use CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="local_rank for distributed training on gpus")
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
@@ -486,16 +647,24 @@ def main():
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
@@ -506,27 +675,34 @@ def main():
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model = model_class.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.teacher_type is not None:
|
||||
assert args.teacher_name_or_path is not None
|
||||
assert args.alpha_ce > 0.
|
||||
assert args.alpha_ce + args.alpha_squad > 0.
|
||||
assert args.teacher_type != 'distilbert', "We constraint teachers not to be of type DistilBERT."
|
||||
assert args.alpha_ce > 0.0
|
||||
assert args.alpha_ce + args.alpha_squad > 0.0
|
||||
assert args.teacher_type != "distilbert", "We constraint teachers not to be of type DistilBERT."
|
||||
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
|
||||
teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
teacher = teacher_model_class.from_pretrained(args.teacher_name_or_path,
|
||||
config=teacher_config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
teacher_config = teacher_config_class.from_pretrained(
|
||||
args.teacher_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None
|
||||
)
|
||||
teacher = teacher_model_class.from_pretrained(
|
||||
args.teacher_name_or_path, config=teacher_config, cache_dir=args.cache_dir if args.cache_dir else None
|
||||
)
|
||||
teacher.to(args.device)
|
||||
else:
|
||||
teacher = None
|
||||
@@ -544,7 +720,6 @@ def main():
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
# Save the trained model and the tokenizer
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
@@ -554,41 +729,44 @@ def main():
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir, cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.output_dir, do_lower_case=args.do_lower_case, cache_dir=args.cache_dir if args.cache_dir else None
|
||||
)
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint, cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluate
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
|
||||
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
|
||||
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
logger.info("Results: {}".format(results))
|
||||
|
||||
@@ -23,68 +23,65 @@ import numpy as np
|
||||
from transformers import BertTokenizer, RobertaTokenizer, GPT2Tokenizer
|
||||
import logging
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids).")
|
||||
parser.add_argument('--file_path', type=str, default='data/dump.txt',
|
||||
help='The path to the data.')
|
||||
parser.add_argument('--tokenizer_type', type=str, default='bert', choices=['bert', 'roberta', 'gpt2'])
|
||||
parser.add_argument('--tokenizer_name', type=str, default='bert-base-uncased',
|
||||
help="The tokenizer to use.")
|
||||
parser.add_argument('--dump_file', type=str, default='data/dump',
|
||||
help='The dump file prefix.')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids)."
|
||||
)
|
||||
parser.add_argument("--file_path", type=str, default="data/dump.txt", help="The path to the data.")
|
||||
parser.add_argument("--tokenizer_type", type=str, default="bert", choices=["bert", "roberta", "gpt2"])
|
||||
parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased", help="The tokenizer to use.")
|
||||
parser.add_argument("--dump_file", type=str, default="data/dump", help="The dump file prefix.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
logger.info(f'Loading Tokenizer ({args.tokenizer_name})')
|
||||
if args.tokenizer_type == 'bert':
|
||||
logger.info(f"Loading Tokenizer ({args.tokenizer_name})")
|
||||
if args.tokenizer_type == "bert":
|
||||
tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
|
||||
bos = tokenizer.special_tokens_map['cls_token'] # `[CLS]`
|
||||
sep = tokenizer.special_tokens_map['sep_token'] # `[SEP]`
|
||||
elif args.tokenizer_type == 'roberta':
|
||||
bos = tokenizer.special_tokens_map["cls_token"] # `[CLS]`
|
||||
sep = tokenizer.special_tokens_map["sep_token"] # `[SEP]`
|
||||
elif args.tokenizer_type == "roberta":
|
||||
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
|
||||
bos = tokenizer.special_tokens_map['cls_token'] # `<s>`
|
||||
sep = tokenizer.special_tokens_map['sep_token'] # `</s>`
|
||||
elif args.tokenizer_type == 'gpt2':
|
||||
bos = tokenizer.special_tokens_map["cls_token"] # `<s>`
|
||||
sep = tokenizer.special_tokens_map["sep_token"] # `</s>`
|
||||
elif args.tokenizer_type == "gpt2":
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
|
||||
bos = tokenizer.special_tokens_map['bos_token'] # `<|endoftext|>`
|
||||
sep = tokenizer.special_tokens_map['eos_token'] # `<|endoftext|>`
|
||||
bos = tokenizer.special_tokens_map["bos_token"] # `<|endoftext|>`
|
||||
sep = tokenizer.special_tokens_map["eos_token"] # `<|endoftext|>`
|
||||
|
||||
logger.info(f'Loading text from {args.file_path}')
|
||||
with open(args.file_path, 'r', encoding='utf8') as fp:
|
||||
logger.info(f"Loading text from {args.file_path}")
|
||||
with open(args.file_path, "r", encoding="utf8") as fp:
|
||||
data = fp.readlines()
|
||||
|
||||
|
||||
logger.info(f'Start encoding')
|
||||
logger.info(f'{len(data)} examples to process.')
|
||||
logger.info(f"Start encoding")
|
||||
logger.info(f"{len(data)} examples to process.")
|
||||
|
||||
rslt = []
|
||||
iter = 0
|
||||
interval = 10000
|
||||
start = time.time()
|
||||
for text in data:
|
||||
text = f'{bos} {text.strip()} {sep}'
|
||||
text = f"{bos} {text.strip()} {sep}"
|
||||
token_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
rslt.append(token_ids)
|
||||
|
||||
iter += 1
|
||||
if iter % interval == 0:
|
||||
end = time.time()
|
||||
logger.info(f'{iter} examples processed. - {(end-start)/interval:.2f}s/expl')
|
||||
logger.info(f"{iter} examples processed. - {(end-start)/interval:.2f}s/expl")
|
||||
start = time.time()
|
||||
logger.info('Finished binarization')
|
||||
logger.info(f'{len(data)} examples processed.')
|
||||
logger.info("Finished binarization")
|
||||
logger.info(f"{len(data)} examples processed.")
|
||||
|
||||
|
||||
dp_file = f'{args.dump_file}.{args.tokenizer_name}.pickle'
|
||||
dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle"
|
||||
rslt_ = [np.uint16(d) for d in rslt]
|
||||
random.shuffle(rslt_)
|
||||
logger.info(f'Dump to {dp_file}')
|
||||
with open(dp_file, 'wb') as handle:
|
||||
logger.info(f"Dump to {dp_file}")
|
||||
with open(dp_file, "wb") as handle:
|
||||
pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
|
||||
@@ -20,70 +20,80 @@ from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation")
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
|
||||
)
|
||||
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
|
||||
parser.add_argument("--model_name", default='roberta-large', type=str)
|
||||
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_roberta_048131723.pth', type=str)
|
||||
parser.add_argument("--vocab_transform", action='store_true')
|
||||
parser.add_argument("--model_name", default="roberta-large", type=str)
|
||||
parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
|
||||
parser.add_argument("--vocab_transform", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if args.model_type == 'roberta':
|
||||
if args.model_type == "roberta":
|
||||
model = RobertaForMaskedLM.from_pretrained(args.model_name)
|
||||
prefix = 'roberta'
|
||||
elif args.model_type == 'gpt2':
|
||||
prefix = "roberta"
|
||||
elif args.model_type == "gpt2":
|
||||
model = GPT2LMHeadModel.from_pretrained(args.model_name)
|
||||
prefix = 'transformer'
|
||||
prefix = "transformer"
|
||||
|
||||
state_dict = model.state_dict()
|
||||
compressed_sd = {}
|
||||
|
||||
### Embeddings ###
|
||||
if args.model_type == 'gpt2':
|
||||
for param_name in ['wte.weight', 'wpe.weight']:
|
||||
compressed_sd[f'{prefix}.{param_name}'] = state_dict[f'{prefix}.{param_name}']
|
||||
if args.model_type == "gpt2":
|
||||
for param_name in ["wte.weight", "wpe.weight"]:
|
||||
compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
|
||||
else:
|
||||
for w in ['word_embeddings', 'position_embeddings', 'token_type_embeddings']:
|
||||
param_name = f'{prefix}.embeddings.{w}.weight'
|
||||
for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
|
||||
param_name = f"{prefix}.embeddings.{w}.weight"
|
||||
compressed_sd[param_name] = state_dict[param_name]
|
||||
for w in ['weight', 'bias']:
|
||||
param_name = f'{prefix}.embeddings.LayerNorm.{w}'
|
||||
for w in ["weight", "bias"]:
|
||||
param_name = f"{prefix}.embeddings.LayerNorm.{w}"
|
||||
compressed_sd[param_name] = state_dict[param_name]
|
||||
|
||||
### Transformer Blocks ###
|
||||
std_idx = 0
|
||||
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
||||
if args.model_type == 'gpt2':
|
||||
for layer in ['ln_1', 'attn.c_attn', 'attn.c_proj', 'ln_2', 'mlp.c_fc', 'mlp.c_proj']:
|
||||
for w in ['weight', 'bias']:
|
||||
compressed_sd[f'{prefix}.h.{std_idx}.{layer}.{w}'] = \
|
||||
state_dict[f'{prefix}.h.{teacher_idx}.{layer}.{w}']
|
||||
compressed_sd[f'{prefix}.h.{std_idx}.attn.bias'] = state_dict[f'{prefix}.h.{teacher_idx}.attn.bias']
|
||||
if args.model_type == "gpt2":
|
||||
for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
|
||||
f"{prefix}.h.{teacher_idx}.{layer}.{w}"
|
||||
]
|
||||
compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
|
||||
else:
|
||||
for layer in ['attention.self.query', 'attention.self.key', 'attention.self.value',
|
||||
'attention.output.dense', 'attention.output.LayerNorm',
|
||||
'intermediate.dense', 'output.dense', 'output.LayerNorm']:
|
||||
for w in ['weight', 'bias']:
|
||||
compressed_sd[f'{prefix}.encoder.layer.{std_idx}.{layer}.{w}'] = \
|
||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}']
|
||||
for layer in [
|
||||
"attention.self.query",
|
||||
"attention.self.key",
|
||||
"attention.self.value",
|
||||
"attention.output.dense",
|
||||
"attention.output.LayerNorm",
|
||||
"intermediate.dense",
|
||||
"output.dense",
|
||||
"output.LayerNorm",
|
||||
]:
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
|
||||
]
|
||||
std_idx += 1
|
||||
|
||||
### Language Modeling Head ###s
|
||||
if args.model_type == 'roberta':
|
||||
for layer in ['lm_head.decoder.weight', 'lm_head.bias']:
|
||||
compressed_sd[f'{layer}'] = state_dict[f'{layer}']
|
||||
if args.model_type == "roberta":
|
||||
for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
|
||||
compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
|
||||
if args.vocab_transform:
|
||||
for w in ['weight', 'bias']:
|
||||
compressed_sd[f'lm_head.dense.{w}'] = state_dict[f'lm_head.dense.{w}']
|
||||
compressed_sd[f'lm_head.layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']
|
||||
elif args.model_type == 'gpt2':
|
||||
for w in ['weight', 'bias']:
|
||||
compressed_sd[f'{prefix}.ln_f.{w}'] = state_dict[f'{prefix}.ln_f.{w}']
|
||||
compressed_sd[f'lm_head.weight'] = state_dict[f'lm_head.weight']
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
|
||||
compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
|
||||
elif args.model_type == "gpt2":
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
|
||||
compressed_sd[f"lm_head.weight"] = state_dict[f"lm_head.weight"]
|
||||
|
||||
print(f'N layers selected for distillation: {std_idx}')
|
||||
print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')
|
||||
print(f"N layers selected for distillation: {std_idx}")
|
||||
print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")
|
||||
|
||||
print(f'Save transfered checkpoint to {args.dump_checkpoint}.')
|
||||
print(f"Save transfered checkpoint to {args.dump_checkpoint}.")
|
||||
torch.save(compressed_sd, args.dump_checkpoint)
|
||||
|
||||
@@ -20,63 +20,70 @@ from transformers import BertForMaskedLM, RobertaForMaskedLM
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation")
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
|
||||
)
|
||||
parser.add_argument("--model_type", default="bert", choices=["bert"])
|
||||
parser.add_argument("--model_name", default='bert-base-uncased', type=str)
|
||||
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str)
|
||||
parser.add_argument("--vocab_transform", action='store_true')
|
||||
parser.add_argument("--model_name", default="bert-base-uncased", type=str)
|
||||
parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
|
||||
parser.add_argument("--vocab_transform", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if args.model_type == 'bert':
|
||||
if args.model_type == "bert":
|
||||
model = BertForMaskedLM.from_pretrained(args.model_name)
|
||||
prefix = 'bert'
|
||||
prefix = "bert"
|
||||
else:
|
||||
raise ValueError(f'args.model_type should be "bert".')
|
||||
|
||||
state_dict = model.state_dict()
|
||||
compressed_sd = {}
|
||||
|
||||
for w in ['word_embeddings', 'position_embeddings']:
|
||||
compressed_sd[f'distilbert.embeddings.{w}.weight'] = \
|
||||
state_dict[f'{prefix}.embeddings.{w}.weight']
|
||||
for w in ['weight', 'bias']:
|
||||
compressed_sd[f'distilbert.embeddings.LayerNorm.{w}'] = \
|
||||
state_dict[f'{prefix}.embeddings.LayerNorm.{w}']
|
||||
for w in ["word_embeddings", "position_embeddings"]:
|
||||
compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]
|
||||
|
||||
std_idx = 0
|
||||
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
||||
for w in ['weight', 'bias']:
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \
|
||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}']
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \
|
||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}']
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \
|
||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}']
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
|
||||
]
|
||||
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \
|
||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}']
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \
|
||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}']
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
|
||||
]
|
||||
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \
|
||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}']
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \
|
||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}']
|
||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
|
||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
|
||||
]
|
||||
compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
|
||||
f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
|
||||
]
|
||||
std_idx += 1
|
||||
|
||||
compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
|
||||
compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
|
||||
compressed_sd[f"vocab_projector.weight"] = state_dict[f"cls.predictions.decoder.weight"]
|
||||
compressed_sd[f"vocab_projector.bias"] = state_dict[f"cls.predictions.bias"]
|
||||
if args.vocab_transform:
|
||||
for w in ['weight', 'bias']:
|
||||
compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
|
||||
compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
|
||||
for w in ["weight", "bias"]:
|
||||
compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
|
||||
compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
|
||||
|
||||
print(f'N layers selected for distillation: {std_idx}')
|
||||
print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')
|
||||
print(f"N layers selected for distillation: {std_idx}")
|
||||
print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")
|
||||
|
||||
print(f'Save transfered checkpoint to {args.dump_checkpoint}.')
|
||||
print(f"Save transfered checkpoint to {args.dump_checkpoint}.")
|
||||
torch.save(compressed_sd, args.dump_checkpoint)
|
||||
|
||||
@@ -20,32 +20,36 @@ import argparse
|
||||
import pickle
|
||||
import logging
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)")
|
||||
parser.add_argument("--data_file", type=str, default="data/dump.bert-base-uncased.pickle",
|
||||
help="The binarized dataset.")
|
||||
parser.add_argument("--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle",
|
||||
help="The dump file.")
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_file", type=str, default="data/dump.bert-base-uncased.pickle", help="The binarized dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", help="The dump file."
|
||||
)
|
||||
parser.add_argument("--vocab_size", default=30522, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info(f'Loading data from {args.data_file}')
|
||||
with open(args.data_file, 'rb') as fp:
|
||||
logger.info(f"Loading data from {args.data_file}")
|
||||
with open(args.data_file, "rb") as fp:
|
||||
data = pickle.load(fp)
|
||||
|
||||
logger.info('Counting occurences for MLM.')
|
||||
logger.info("Counting occurences for MLM.")
|
||||
counter = Counter()
|
||||
for tk_ids in data:
|
||||
counter.update(tk_ids)
|
||||
counts = [0]*args.vocab_size
|
||||
counts = [0] * args.vocab_size
|
||||
for k, v in counter.items():
|
||||
counts[k] = v
|
||||
|
||||
logger.info(f'Dump to {args.token_counts_dump}')
|
||||
with open(args.token_counts_dump, 'wb') as handle:
|
||||
logger.info(f"Dump to {args.token_counts_dump}")
|
||||
with open(args.token_counts_dump, "wb") as handle:
|
||||
pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
@@ -35,166 +35,200 @@ from lm_seqs_dataset import LmSeqsDataset
|
||||
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
|
||||
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||
}
|
||||
|
||||
|
||||
def sanity_checks(args):
|
||||
"""
|
||||
A bunch of args sanity checks to perform even starting...
|
||||
"""
|
||||
assert (args.mlm and args.alpha_mlm > 0.) or (not args.mlm and args.alpha_mlm == 0.)
|
||||
assert (args.alpha_mlm > 0. and args.alpha_clm == 0.) or (args.alpha_mlm == 0. and args.alpha_clm > 0.)
|
||||
assert (args.mlm and args.alpha_mlm > 0.0) or (not args.mlm and args.alpha_mlm == 0.0)
|
||||
assert (args.alpha_mlm > 0.0 and args.alpha_clm == 0.0) or (args.alpha_mlm == 0.0 and args.alpha_clm > 0.0)
|
||||
if args.mlm:
|
||||
assert os.path.isfile(args.token_counts)
|
||||
assert (args.student_type in ['roberta', 'distilbert']) and (args.teacher_type in ['roberta', 'bert'])
|
||||
assert (args.student_type in ["roberta", "distilbert"]) and (args.teacher_type in ["roberta", "bert"])
|
||||
else:
|
||||
assert (args.student_type in ['gpt2']) and (args.teacher_type in ['gpt2'])
|
||||
assert (args.student_type in ["gpt2"]) and (args.teacher_type in ["gpt2"])
|
||||
|
||||
assert args.teacher_type == args.student_type or (args.student_type=='distilbert' and args.teacher_type=='bert')
|
||||
assert args.teacher_type == args.student_type or (
|
||||
args.student_type == "distilbert" and args.teacher_type == "bert"
|
||||
)
|
||||
assert os.path.isfile(args.student_config)
|
||||
if args.student_pretrained_weights is not None:
|
||||
assert os.path.isfile(args.student_pretrained_weights)
|
||||
|
||||
if args.freeze_token_type_embds: assert args.student_type in ['roberta']
|
||||
if args.freeze_token_type_embds:
|
||||
assert args.student_type in ["roberta"]
|
||||
|
||||
assert args.alpha_ce >= 0.0
|
||||
assert args.alpha_mlm >= 0.0
|
||||
assert args.alpha_clm >= 0.0
|
||||
assert args.alpha_mse >= 0.0
|
||||
assert args.alpha_cos >= 0.0
|
||||
assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.0
|
||||
|
||||
assert args.alpha_ce >= 0.
|
||||
assert args.alpha_mlm >= 0.
|
||||
assert args.alpha_clm >= 0.
|
||||
assert args.alpha_mse >= 0.
|
||||
assert args.alpha_cos >= 0.
|
||||
assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.
|
||||
|
||||
def freeze_pos_embeddings(student, args):
|
||||
if args.student_type == 'roberta':
|
||||
if args.student_type == "roberta":
|
||||
student.roberta.embeddings.position_embeddings.weight.requires_grad = False
|
||||
elif args.student_type == 'gpt2':
|
||||
elif args.student_type == "gpt2":
|
||||
student.transformer.wpe.weight.requires_grad = False
|
||||
|
||||
|
||||
def freeze_token_type_embeddings(student, args):
|
||||
if args.student_type == 'roberta':
|
||||
if args.student_type == "roberta":
|
||||
student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Training")
|
||||
parser.add_argument("--force", action='store_true',
|
||||
help="Overwrite dump_path if it already exists.")
|
||||
parser.add_argument("--force", action="store_true", help="Overwrite dump_path if it already exists.")
|
||||
|
||||
parser.add_argument("--dump_path", type=str, required=True,
|
||||
help="The output directory (log, checkpoints, parameters, etc.)")
|
||||
parser.add_argument("--data_file", type=str, required=True,
|
||||
help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.")
|
||||
parser.add_argument(
|
||||
"--dump_path", type=str, required=True, help="The output directory (log, checkpoints, parameters, etc.)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.",
|
||||
)
|
||||
|
||||
parser.add_argument("--student_type", type=str, choices=["distilbert", "roberta", "gpt2"], required=True,
|
||||
help="The student type (DistilBERT, RoBERTa).")
|
||||
parser.add_argument("--student_config", type=str, required=True,
|
||||
help="Path to the student configuration.")
|
||||
parser.add_argument("--student_pretrained_weights", default=None, type=str,
|
||||
help="Load student initialization checkpoint.")
|
||||
parser.add_argument(
|
||||
"--student_type",
|
||||
type=str,
|
||||
choices=["distilbert", "roberta", "gpt2"],
|
||||
required=True,
|
||||
help="The student type (DistilBERT, RoBERTa).",
|
||||
)
|
||||
parser.add_argument("--student_config", type=str, required=True, help="Path to the student configuration.")
|
||||
parser.add_argument(
|
||||
"--student_pretrained_weights", default=None, type=str, help="Load student initialization checkpoint."
|
||||
)
|
||||
|
||||
parser.add_argument("--teacher_type", choices=["bert", "roberta", "gpt2"], required=True,
|
||||
help="Teacher type (BERT, RoBERTa).")
|
||||
parser.add_argument("--teacher_name", type=str, required=True,
|
||||
help="The teacher model.")
|
||||
parser.add_argument(
|
||||
"--teacher_type", choices=["bert", "roberta", "gpt2"], required=True, help="Teacher type (BERT, RoBERTa)."
|
||||
)
|
||||
parser.add_argument("--teacher_name", type=str, required=True, help="The teacher model.")
|
||||
|
||||
parser.add_argument("--temperature", default=2., type=float,
|
||||
help="Temperature for the softmax temperature.")
|
||||
parser.add_argument("--alpha_ce", default=0.5, type=float,
|
||||
help="Linear weight for the distillation loss. Must be >=0.")
|
||||
parser.add_argument("--alpha_mlm", default=0.0, type=float,
|
||||
help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.")
|
||||
parser.add_argument("--alpha_clm", default=0.5, type=float,
|
||||
help="Linear weight for the CLM loss. Must be >=0.")
|
||||
parser.add_argument("--alpha_mse", default=0.0, type=float,
|
||||
help="Linear weight of the MSE loss. Must be >=0.")
|
||||
parser.add_argument("--alpha_cos", default=0.0, type=float,
|
||||
help="Linear weight of the cosine embedding loss. Must be >=0.")
|
||||
parser.add_argument("--temperature", default=2.0, type=float, help="Temperature for the softmax temperature.")
|
||||
parser.add_argument(
|
||||
"--alpha_ce", default=0.5, type=float, help="Linear weight for the distillation loss. Must be >=0."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_mlm",
|
||||
default=0.0,
|
||||
type=float,
|
||||
help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.",
|
||||
)
|
||||
parser.add_argument("--alpha_clm", default=0.5, type=float, help="Linear weight for the CLM loss. Must be >=0.")
|
||||
parser.add_argument("--alpha_mse", default=0.0, type=float, help="Linear weight of the MSE loss. Must be >=0.")
|
||||
parser.add_argument(
|
||||
"--alpha_cos", default=0.0, type=float, help="Linear weight of the cosine embedding loss. Must be >=0."
|
||||
)
|
||||
|
||||
parser.add_argument("--mlm", action="store_true",
|
||||
help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM.")
|
||||
parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
|
||||
help="Proportion of tokens for which we need to make a prediction.")
|
||||
parser.add_argument("--word_mask", default=0.8, type=float,
|
||||
help="Proportion of tokens to mask out.")
|
||||
parser.add_argument("--word_keep", default=0.1, type=float,
|
||||
help="Proportion of tokens to keep.")
|
||||
parser.add_argument("--word_rand", default=0.1, type=float,
|
||||
help="Proportion of tokens to randomly replace.")
|
||||
parser.add_argument("--mlm_smoothing", default=0.7, type=float,
|
||||
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).")
|
||||
parser.add_argument("--token_counts", type=str,
|
||||
help="The token counts in the data_file for MLM.")
|
||||
parser.add_argument(
|
||||
"--mlm", action="store_true", help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlm_mask_prop",
|
||||
default=0.15,
|
||||
type=float,
|
||||
help="Proportion of tokens for which we need to make a prediction.",
|
||||
)
|
||||
parser.add_argument("--word_mask", default=0.8, type=float, help="Proportion of tokens to mask out.")
|
||||
parser.add_argument("--word_keep", default=0.1, type=float, help="Proportion of tokens to keep.")
|
||||
parser.add_argument("--word_rand", default=0.1, type=float, help="Proportion of tokens to randomly replace.")
|
||||
parser.add_argument(
|
||||
"--mlm_smoothing",
|
||||
default=0.7,
|
||||
type=float,
|
||||
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).",
|
||||
)
|
||||
parser.add_argument("--token_counts", type=str, help="The token counts in the data_file for MLM.")
|
||||
|
||||
parser.add_argument("--restrict_ce_to_mask", action='store_true',
|
||||
help="If true, compute the distilation loss only the [MLM] prediction distribution.")
|
||||
parser.add_argument("--freeze_pos_embs", action="store_true",
|
||||
help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.")
|
||||
parser.add_argument("--freeze_token_type_embds", action="store_true",
|
||||
help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.")
|
||||
parser.add_argument(
|
||||
"--restrict_ce_to_mask",
|
||||
action="store_true",
|
||||
help="If true, compute the distilation loss only the [MLM] prediction distribution.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_pos_embs",
|
||||
action="store_true",
|
||||
help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_token_type_embds",
|
||||
action="store_true",
|
||||
help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.",
|
||||
)
|
||||
|
||||
parser.add_argument("--n_epoch", type=int, default=3,
|
||||
help="Number of pass on the whole dataset.")
|
||||
parser.add_argument("--batch_size", type=int, default=5,
|
||||
help="Batch size (for each process).")
|
||||
parser.add_argument("--group_by_size", action='store_false',
|
||||
help="If true, group sequences that have similar length into the same batch. Default is true.")
|
||||
parser.add_argument("--n_epoch", type=int, default=3, help="Number of pass on the whole dataset.")
|
||||
parser.add_argument("--batch_size", type=int, default=5, help="Batch size (for each process).")
|
||||
parser.add_argument(
|
||||
"--group_by_size",
|
||||
action="store_false",
|
||||
help="If true, group sequences that have similar length into the same batch. Default is true.",
|
||||
)
|
||||
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=50,
|
||||
help="Gradient accumulation for larger training batches.")
|
||||
parser.add_argument("--warmup_prop", default=0.05, type=float,
|
||||
help="Linear warmup proportion.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
parser.add_argument("--learning_rate", default=5e-4, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-6, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=5.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--initializer_range", default=0.02, type=float,
|
||||
help="Random initialization range.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Gradient accumulation for larger training batches.",
|
||||
)
|
||||
parser.add_argument("--warmup_prop", default=0.05, type=float, help="Linear warmup proportion.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||
parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--initializer_range", default=0.02, type=float, help="Random initialization range.")
|
||||
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument("--n_gpu", type=int, default=1,
|
||||
help="Number of GPUs in the node.")
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="Distributed training - Local rank")
|
||||
parser.add_argument("--seed", type=int, default=56,
|
||||
help="Random seed")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
|
||||
parser.add_argument("--seed", type=int, default=56, help="Random seed")
|
||||
|
||||
parser.add_argument("--log_interval", type=int, default=500,
|
||||
help="Tensorboard logging interval.")
|
||||
parser.add_argument("--checkpoint_interval", type=int, default=4000,
|
||||
help="Checkpoint interval.")
|
||||
parser.add_argument("--log_interval", type=int, default=500, help="Tensorboard logging interval.")
|
||||
parser.add_argument("--checkpoint_interval", type=int, default=4000, help="Checkpoint interval.")
|
||||
args = parser.parse_args()
|
||||
sanity_checks(args)
|
||||
|
||||
|
||||
## ARGS ##
|
||||
init_gpu_params(args)
|
||||
set_seed(args)
|
||||
if args.is_master:
|
||||
if os.path.exists(args.dump_path):
|
||||
if not args.force:
|
||||
raise ValueError(f'Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it'
|
||||
'Use `--force` if you want to overwrite it')
|
||||
raise ValueError(
|
||||
f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it"
|
||||
"Use `--force` if you want to overwrite it"
|
||||
)
|
||||
else:
|
||||
shutil.rmtree(args.dump_path)
|
||||
|
||||
if not os.path.exists(args.dump_path):
|
||||
os.makedirs(args.dump_path)
|
||||
logger.info(f'Experiment will be dumped and logged in {args.dump_path}')
|
||||
|
||||
logger.info(f"Experiment will be dumped and logged in {args.dump_path}")
|
||||
|
||||
### SAVE PARAMS ###
|
||||
logger.info(f'Param: {args}')
|
||||
with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
|
||||
logger.info(f"Param: {args}")
|
||||
with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
git_log(args.dump_path)
|
||||
|
||||
@@ -207,58 +241,50 @@ def main():
|
||||
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
|
||||
idx = tokenizer.all_special_tokens.index(tok_symbol)
|
||||
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
|
||||
logger.info(f'Special tokens {special_tok_ids}')
|
||||
logger.info(f"Special tokens {special_tok_ids}")
|
||||
args.special_tok_ids = special_tok_ids
|
||||
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
|
||||
|
||||
|
||||
## DATA LOADER ##
|
||||
logger.info(f'Loading data from {args.data_file}')
|
||||
with open(args.data_file, 'rb') as fp:
|
||||
logger.info(f"Loading data from {args.data_file}")
|
||||
with open(args.data_file, "rb") as fp:
|
||||
data = pickle.load(fp)
|
||||
|
||||
|
||||
if args.mlm:
|
||||
logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
|
||||
with open(args.token_counts, 'rb') as fp:
|
||||
logger.info(f"Loading token counts from {args.token_counts} (already pre-computed)")
|
||||
with open(args.token_counts, "rb") as fp:
|
||||
counts = pickle.load(fp)
|
||||
|
||||
token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
|
||||
for idx in special_tok_ids.values():
|
||||
token_probs[idx] = 0. # do not predict special tokens
|
||||
token_probs[idx] = 0.0 # do not predict special tokens
|
||||
token_probs = torch.from_numpy(token_probs)
|
||||
else:
|
||||
token_probs = None
|
||||
|
||||
|
||||
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
|
||||
logger.info(f'Data loader created.')
|
||||
|
||||
logger.info(f"Data loader created.")
|
||||
|
||||
## STUDENT ##
|
||||
logger.info(f'Loading student config from {args.student_config}')
|
||||
logger.info(f"Loading student config from {args.student_config}")
|
||||
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
|
||||
stu_architecture_config.output_hidden_states = True
|
||||
|
||||
if args.student_pretrained_weights is not None:
|
||||
logger.info(f'Loading pretrained weights from {args.student_pretrained_weights}')
|
||||
student = student_model_class.from_pretrained(args.student_pretrained_weights,
|
||||
config=stu_architecture_config)
|
||||
logger.info(f"Loading pretrained weights from {args.student_pretrained_weights}")
|
||||
student = student_model_class.from_pretrained(args.student_pretrained_weights, config=stu_architecture_config)
|
||||
else:
|
||||
student = student_model_class(stu_architecture_config)
|
||||
|
||||
|
||||
if args.n_gpu > 0:
|
||||
student.to(f'cuda:{args.local_rank}')
|
||||
logger.info(f'Student loaded.')
|
||||
|
||||
student.to(f"cuda:{args.local_rank}")
|
||||
logger.info(f"Student loaded.")
|
||||
|
||||
## TEACHER ##
|
||||
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
|
||||
if args.n_gpu > 0:
|
||||
teacher.to(f'cuda:{args.local_rank}')
|
||||
logger.info(f'Teacher loaded from {args.teacher_name}.')
|
||||
|
||||
teacher.to(f"cuda:{args.local_rank}")
|
||||
logger.info(f"Teacher loaded from {args.teacher_name}.")
|
||||
|
||||
## FREEZING ##
|
||||
if args.freeze_pos_embs:
|
||||
@@ -266,7 +292,6 @@ def main():
|
||||
if args.freeze_token_type_embds:
|
||||
freeze_token_type_embeddings(student, args)
|
||||
|
||||
|
||||
## SANITY CHECKS ##
|
||||
assert student.config.vocab_size == teacher.config.vocab_size
|
||||
assert student.config.hidden_size == teacher.config.hidden_size
|
||||
@@ -274,14 +299,11 @@ def main():
|
||||
if args.mlm:
|
||||
assert token_probs.size(0) == stu_architecture_config.vocab_size
|
||||
|
||||
|
||||
## DISTILLER ##
|
||||
torch.cuda.empty_cache()
|
||||
distiller = Distiller(params=args,
|
||||
dataset=train_lm_seq_dataset,
|
||||
token_probs=token_probs,
|
||||
student=student,
|
||||
teacher=teacher)
|
||||
distiller = Distiller(
|
||||
params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
|
||||
)
|
||||
distiller.train()
|
||||
logger.info("Let's go get some drinks.")
|
||||
|
||||
|
||||
@@ -23,9 +23,12 @@ import torch
|
||||
import numpy as np
|
||||
|
||||
import logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO)
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -35,12 +38,12 @@ def git_log(folder_path: str):
|
||||
"""
|
||||
repo = git.Repo(search_parent_directories=True)
|
||||
repo_infos = {
|
||||
'repo_id': str(repo),
|
||||
'repo_sha': str(repo.head.object.hexsha),
|
||||
'repo_branch': str(repo.active_branch)
|
||||
"repo_id": str(repo),
|
||||
"repo_sha": str(repo.head.object.hexsha),
|
||||
"repo_branch": str(repo.active_branch),
|
||||
}
|
||||
|
||||
with open(os.path.join(folder_path, 'git_log.json'), 'w') as f:
|
||||
with open(os.path.join(folder_path, "git_log.json"), "w") as f:
|
||||
json.dump(repo_infos, f, indent=4)
|
||||
|
||||
|
||||
@@ -57,21 +60,21 @@ def init_gpu_params(params):
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
logger.info('Initializing GPUs')
|
||||
logger.info("Initializing GPUs")
|
||||
if params.n_gpu > 1:
|
||||
assert params.local_rank != -1
|
||||
|
||||
params.world_size = int(os.environ['WORLD_SIZE'])
|
||||
params.n_gpu_per_node = int(os.environ['N_GPU_NODE'])
|
||||
params.global_rank = int(os.environ['RANK'])
|
||||
params.world_size = int(os.environ["WORLD_SIZE"])
|
||||
params.n_gpu_per_node = int(os.environ["N_GPU_NODE"])
|
||||
params.global_rank = int(os.environ["RANK"])
|
||||
|
||||
# number of nodes / node ID
|
||||
params.n_nodes = params.world_size // params.n_gpu_per_node
|
||||
params.node_id = params.global_rank // params.n_gpu_per_node
|
||||
params.multi_gpu = True
|
||||
|
||||
assert params.n_nodes == int(os.environ['N_NODES'])
|
||||
assert params.node_id == int(os.environ['NODE_RANK'])
|
||||
assert params.n_nodes == int(os.environ["N_NODES"])
|
||||
assert params.node_id == int(os.environ["NODE_RANK"])
|
||||
|
||||
# local job (single GPU)
|
||||
else:
|
||||
@@ -114,8 +117,7 @@ def init_gpu_params(params):
|
||||
if params.multi_gpu:
|
||||
logger.info("Initializing PyTorch distributed")
|
||||
torch.distributed.init_process_group(
|
||||
init_method='env://',
|
||||
backend='nccl',
|
||||
init_method="env://", backend="nccl",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -40,29 +40,49 @@ from tqdm import tqdm, trange
|
||||
|
||||
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_mmimdb_labels, get_image_transforms
|
||||
|
||||
from transformers import (WEIGHTS_NAME,
|
||||
BertConfig, BertModel, BertTokenizer,
|
||||
RobertaConfig, RobertaModel, RobertaTokenizer,
|
||||
XLMConfig, XLMModel, XLMTokenizer,
|
||||
XLNetConfig, XLNetModel, XLNetTokenizer,
|
||||
DistilBertConfig, DistilBertModel, DistilBertTokenizer,
|
||||
AlbertConfig, AlbertModel, AlbertTokenizer,
|
||||
MMBTForClassification, MMBTConfig)
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
BertConfig,
|
||||
BertModel,
|
||||
BertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaModel,
|
||||
RobertaTokenizer,
|
||||
XLMConfig,
|
||||
XLMModel,
|
||||
XLMTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetModel,
|
||||
XLNetTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertModel,
|
||||
DistilBertTokenizer,
|
||||
AlbertConfig,
|
||||
AlbertModel,
|
||||
AlbertTokenizer,
|
||||
MMBTForClassification,
|
||||
MMBTConfig,
|
||||
)
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig,
|
||||
RobertaConfig, DistilBertConfig)), ())
|
||||
ALL_MODELS = sum(
|
||||
(
|
||||
tuple(conf.pretrained_config_archive_map.keys())
|
||||
for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
|
||||
),
|
||||
(),
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertModel, BertTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetModel, XLNetTokenizer),
|
||||
'xlm': (XLMConfig, XLMModel, XLMTokenizer),
|
||||
'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertModel, DistilBertTokenizer),
|
||||
'albert': (AlbertConfig, AlbertModel, AlbertTokenizer)
|
||||
"bert": (BertConfig, BertModel, BertTokenizer),
|
||||
"xlnet": (XLNetConfig, XLNetModel, XLNetTokenizer),
|
||||
"xlm": (XLMConfig, XLMModel, XLMTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaModel, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertModel, DistilBertTokenizer),
|
||||
"albert": (AlbertConfig, AlbertModel, AlbertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
@@ -81,10 +101,13 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
|
||||
batch_size=args.train_batch_size,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=args.num_workers)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
sampler=train_sampler,
|
||||
batch_size=args.train_batch_size,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=args.num_workers,
|
||||
)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
@@ -93,14 +116,19 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
@@ -114,17 +142,21 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
@@ -140,17 +172,19 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
labels = batch[5]
|
||||
inputs = {'input_ids': batch[0],
|
||||
'input_modal': batch[2],
|
||||
'attention_mask': batch[1],
|
||||
'modal_start_tokens': batch[3],
|
||||
'modal_end_tokens': batch[4]}
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"input_modal": batch[2],
|
||||
"attention_mask": batch[1],
|
||||
"modal_start_tokens": batch[3],
|
||||
"modal_end_tokens": batch[4],
|
||||
}
|
||||
outputs = model(**inputs)
|
||||
logits = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
loss = criterion(logits, labels)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
@@ -174,30 +208,34 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
logs = {}
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer, criterion)
|
||||
for key, value in results.items():
|
||||
eval_key = 'eval_{}'.format(key)
|
||||
eval_key = "eval_{}".format(key)
|
||||
logs[eval_key] = value
|
||||
|
||||
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
||||
learning_rate_scalar = scheduler.get_lr()[0]
|
||||
logs['learning_rate'] = learning_rate_scalar
|
||||
logs['loss'] = loss_scalar
|
||||
logs["learning_rate"] = learning_rate_scalar
|
||||
logs["loss"] = loss_scalar
|
||||
logging_loss = tr_loss
|
||||
|
||||
for key, value in logs.items():
|
||||
tb_writer.add_scalar(key, value, global_step)
|
||||
print(json.dumps({**logs, **{'step': global_step}}))
|
||||
print(json.dumps({**logs, **{"step": global_step}}))
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
@@ -209,13 +247,13 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
||||
|
||||
if args.local_rank == -1:
|
||||
results = evaluate(args, model, tokenizer, criterion)
|
||||
if results['micro_f1'] > best_f1:
|
||||
best_f1 = results['micro_f1']
|
||||
if results["micro_f1"] > best_f1:
|
||||
best_f1 = results["micro_f1"]
|
||||
n_no_improve = 0
|
||||
else:
|
||||
n_no_improve += 1
|
||||
|
||||
if n_no_improve > args.patience:
|
||||
if n_no_improve > args.patience:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
@@ -236,7 +274,9 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn)
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
# multi-gpu eval
|
||||
if args.n_gpu > 1:
|
||||
@@ -257,11 +297,13 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
|
||||
with torch.no_grad():
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
labels = batch[5]
|
||||
inputs = {'input_ids': batch[0],
|
||||
'input_modal': batch[2],
|
||||
'attention_mask': batch[1],
|
||||
'modal_start_tokens': batch[3],
|
||||
'modal_end_tokens': batch[4]}
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"input_modal": batch[2],
|
||||
"attention_mask": batch[1],
|
||||
"modal_start_tokens": batch[3],
|
||||
"modal_end_tokens": batch[4],
|
||||
}
|
||||
outputs = model(**inputs)
|
||||
logits = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
tmp_eval_loss = criterion(logits, labels)
|
||||
@@ -278,7 +320,7 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
|
||||
result = {
|
||||
"loss": eval_loss,
|
||||
"macro_f1": f1_score(out_label_ids, preds, average="macro"),
|
||||
"micro_f1": f1_score(out_label_ids, preds, average="micro")
|
||||
"micro_f1": f1_score(out_label_ids, preds, average="micro"),
|
||||
}
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||
@@ -303,94 +345,147 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||
help="The input data dir. Should contain the .jsonl files for MMIMDB.")
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .jsonl files for MMIMDB.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.")
|
||||
parser.add_argument("--num_image_embeds", default=1, type=int,
|
||||
help="Number of Image Embeddings from the Image Encoder")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||
help="Rul evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_image_embeds", default=1, type=int, help="Number of Image Embeddings from the Image Encoder"
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--patience", default=5, type=int,
|
||||
help="Patience for Early Stopping.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument("--patience", default=5, type=int, help="Patience for Early Stopping.")
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument('--num_workers', type=int, default=8,
|
||||
help="number of worker threads for dataloading")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument("--num_workers", type=int, default=8, help="number of worker threads for dataloading")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="For distributed training: local_rank")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
@@ -402,17 +497,25 @@ def main():
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
@@ -426,13 +529,17 @@ def main():
|
||||
num_labels = len(labels)
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
transformer_config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
transformer = model_class.from_pretrained(args.model_name_or_path,
|
||||
config=transformer_config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
transformer_config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
transformer = model_class.from_pretrained(
|
||||
args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir if args.cache_dir else None
|
||||
)
|
||||
img_encoder = ImageEncoder(args)
|
||||
config = MMBTConfig(transformer_config, num_labels=num_labels)
|
||||
model = MMBTForClassification(config, transformer, img_encoder)
|
||||
@@ -449,12 +556,13 @@ def main():
|
||||
train_dataset = load_examples(args, tokenizer, evaluate=False)
|
||||
label_frequences = train_dataset.get_label_frequencies()
|
||||
label_frequences = [label_frequences[l] for l in labels]
|
||||
label_weights = (torch.tensor(label_frequences, device=args.device, dtype=torch.float) / len(train_dataset)) ** -1
|
||||
label_weights = (
|
||||
torch.tensor(label_frequences, device=args.device, dtype=torch.float) / len(train_dataset)
|
||||
) ** -1
|
||||
criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer, criterion)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
@@ -464,12 +572,14 @@ def main():
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = MMBTForClassification(config, transformer, img_encoder)
|
||||
@@ -477,24 +587,25 @@ def main():
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
model = MMBTForClassification(config, transformer, img_encoder)
|
||||
model.load_state_dict(torch.load(checkpoint))
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, criterion, prefix=prefix)
|
||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
@@ -25,17 +25,7 @@ import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
POOLING_BREAKDOWN = {
|
||||
1: (1, 1),
|
||||
2: (2, 1),
|
||||
3: (3, 1),
|
||||
4: (2, 2),
|
||||
5: (5, 1),
|
||||
6: (3, 2),
|
||||
7: (7, 1),
|
||||
8: (4, 2),
|
||||
9: (3, 3)
|
||||
}
|
||||
POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}
|
||||
|
||||
|
||||
class ImageEncoder(nn.Module):
|
||||
@@ -54,7 +44,6 @@ class ImageEncoder(nn.Module):
|
||||
return out # BxNx2048
|
||||
|
||||
|
||||
|
||||
class JsonlDataset(Dataset):
|
||||
def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
|
||||
self.data = [json.loads(l) for l in open(data_path)]
|
||||
@@ -72,7 +61,7 @@ class JsonlDataset(Dataset):
|
||||
def __getitem__(self, index):
|
||||
sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True))
|
||||
start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
|
||||
sentence = sentence[:self.max_seq_length]
|
||||
sentence = sentence[: self.max_seq_length]
|
||||
|
||||
label = torch.zeros(self.n_classes)
|
||||
label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1
|
||||
@@ -80,8 +69,13 @@ class JsonlDataset(Dataset):
|
||||
image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
|
||||
image = self.transforms(image)
|
||||
|
||||
return {"image_start_token": start_token, "image_end_token": end_token,
|
||||
"sentence": sentence, "image": image, "label": label}
|
||||
return {
|
||||
"image_start_token": start_token,
|
||||
"image_end_token": end_token,
|
||||
"sentence": sentence,
|
||||
"image": image,
|
||||
"label": label,
|
||||
}
|
||||
|
||||
def get_label_frequencies(self):
|
||||
label_freqs = Counter()
|
||||
@@ -110,10 +104,31 @@ def collate_fn(batch):
|
||||
|
||||
|
||||
def get_mmimdb_labels():
|
||||
return ['Crime', 'Drama', 'Thriller', 'Action', 'Comedy', 'Romance',
|
||||
'Documentary', 'Short', 'Mystery', 'History', 'Family', 'Adventure',
|
||||
'Fantasy', 'Sci-Fi', 'Western', 'Horror', 'Sport', 'War', 'Music',
|
||||
'Musical', 'Animation', 'Biography', 'Film-Noir']
|
||||
return [
|
||||
"Crime",
|
||||
"Drama",
|
||||
"Thriller",
|
||||
"Action",
|
||||
"Comedy",
|
||||
"Romance",
|
||||
"Documentary",
|
||||
"Short",
|
||||
"Mystery",
|
||||
"History",
|
||||
"Family",
|
||||
"Adventure",
|
||||
"Fantasy",
|
||||
"Sci-Fi",
|
||||
"Western",
|
||||
"Horror",
|
||||
"Sport",
|
||||
"War",
|
||||
"Music",
|
||||
"Musical",
|
||||
"Animation",
|
||||
"Biography",
|
||||
"Film-Noir",
|
||||
]
|
||||
|
||||
|
||||
def get_image_transforms():
|
||||
@@ -122,9 +137,6 @@ def get_image_transforms():
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.46777044, 0.44531429, 0.40661017],
|
||||
std=[0.12221994, 0.12145835, 0.14380469],
|
||||
),
|
||||
transforms.Normalize(mean=[0.46777044, 0.44531429, 0.40661017], std=[0.12221994, 0.12145835, 0.14380469],),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
|
||||
|
||||
class ClassificationHead(torch.nn.Module):
|
||||
"""Classification Head for transformer encoders"""
|
||||
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
#! /usr/bin/env python3
|
||||
# coding=utf-8
|
||||
|
||||
#Copyright (c) 2019 Uber Technologies, Inc.
|
||||
# Copyright (c) 2019 Uber Technologies, Inc.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
#http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example command with bag of words:
|
||||
@@ -46,13 +46,13 @@ SMALL_CONST = 1e-15
|
||||
BIG_CONST = 1e10
|
||||
|
||||
BAG_OF_WORDS_ARCHIVE_MAP = {
|
||||
'legal': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
|
||||
'military': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
|
||||
'politics': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
|
||||
'religion': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
|
||||
'science': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
|
||||
'space': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
|
||||
'technology': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
|
||||
"legal": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
|
||||
"military": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
|
||||
"politics": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
|
||||
"religion": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
|
||||
"science": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
|
||||
"space": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
|
||||
"technology": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
|
||||
}
|
||||
|
||||
DISCRIMINATOR_MODELS_PARAMS = {
|
||||
@@ -75,10 +75,10 @@ DISCRIMINATOR_MODELS_PARAMS = {
|
||||
}
|
||||
|
||||
|
||||
def to_var(x, requires_grad=False, volatile=False, device='cuda'):
|
||||
if torch.cuda.is_available() and device == 'cuda':
|
||||
def to_var(x, requires_grad=False, volatile=False, device="cuda"):
|
||||
if torch.cuda.is_available() and device == "cuda":
|
||||
x = x.cuda()
|
||||
elif device != 'cuda':
|
||||
elif device != "cuda":
|
||||
x = x.to(device)
|
||||
return Variable(x, requires_grad=requires_grad, volatile=volatile)
|
||||
|
||||
@@ -95,49 +95,39 @@ def top_k_filter(logits, k, probs=False):
|
||||
values = torch.topk(logits, k)[0]
|
||||
batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
|
||||
if probs:
|
||||
return torch.where(logits < batch_mins,
|
||||
torch.ones_like(logits) * 0.0, logits)
|
||||
return torch.where(logits < batch_mins,
|
||||
torch.ones_like(logits) * -BIG_CONST,
|
||||
logits)
|
||||
return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits)
|
||||
return torch.where(logits < batch_mins, torch.ones_like(logits) * -BIG_CONST, logits)
|
||||
|
||||
|
||||
def perturb_past(
|
||||
past,
|
||||
model,
|
||||
last,
|
||||
unpert_past=None,
|
||||
unpert_logits=None,
|
||||
accumulated_hidden=None,
|
||||
grad_norms=None,
|
||||
stepsize=0.01,
|
||||
one_hot_bows_vectors=None,
|
||||
classifier=None,
|
||||
class_label=None,
|
||||
loss_type=0,
|
||||
num_iterations=3,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
kl_scale=0.01,
|
||||
device='cuda',
|
||||
past,
|
||||
model,
|
||||
last,
|
||||
unpert_past=None,
|
||||
unpert_logits=None,
|
||||
accumulated_hidden=None,
|
||||
grad_norms=None,
|
||||
stepsize=0.01,
|
||||
one_hot_bows_vectors=None,
|
||||
classifier=None,
|
||||
class_label=None,
|
||||
loss_type=0,
|
||||
num_iterations=3,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
kl_scale=0.01,
|
||||
device="cuda",
|
||||
):
|
||||
# Generate inital perturbed past
|
||||
grad_accumulator = [
|
||||
(np.zeros(p.shape).astype("float32"))
|
||||
for p in past
|
||||
]
|
||||
grad_accumulator = [(np.zeros(p.shape).astype("float32")) for p in past]
|
||||
|
||||
if accumulated_hidden is None:
|
||||
accumulated_hidden = 0
|
||||
|
||||
if decay:
|
||||
decay_mask = torch.arange(
|
||||
0.,
|
||||
1.0 + SMALL_CONST,
|
||||
1.0 / (window_length)
|
||||
)[1:]
|
||||
decay_mask = torch.arange(0.0, 1.0 + SMALL_CONST, 1.0 / (window_length))[1:]
|
||||
else:
|
||||
decay_mask = 1.0
|
||||
|
||||
@@ -146,26 +136,17 @@ def perturb_past(
|
||||
_, _, _, curr_length, _ = past[0].shape
|
||||
|
||||
if curr_length > window_length and window_length > 0:
|
||||
ones_key_val_shape = (
|
||||
tuple(past[0].shape[:-2])
|
||||
+ tuple([window_length])
|
||||
+ tuple(past[0].shape[-1:])
|
||||
)
|
||||
ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])
|
||||
|
||||
zeros_key_val_shape = (
|
||||
tuple(past[0].shape[:-2])
|
||||
+ tuple([curr_length - window_length])
|
||||
+ tuple(past[0].shape[-1:])
|
||||
tuple(past[0].shape[:-2]) + tuple([curr_length - window_length]) + tuple(past[0].shape[-1:])
|
||||
)
|
||||
|
||||
ones_mask = torch.ones(ones_key_val_shape)
|
||||
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
|
||||
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
|
||||
|
||||
window_mask = torch.cat(
|
||||
(ones_mask, torch.zeros(zeros_key_val_shape)),
|
||||
dim=-2
|
||||
).to(device)
|
||||
window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2).to(device)
|
||||
else:
|
||||
window_mask = torch.ones_like(past[0]).to(device)
|
||||
|
||||
@@ -175,8 +156,7 @@ def perturb_past(
|
||||
for i in range(num_iterations):
|
||||
print("Iteration ", i + 1)
|
||||
curr_perturbation = [
|
||||
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
|
||||
for p_ in grad_accumulator
|
||||
to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator
|
||||
]
|
||||
|
||||
# Compute hidden using perturbed past
|
||||
@@ -184,10 +164,7 @@ def perturb_past(
|
||||
_, _, _, curr_length, _ = curr_perturbation[0].shape
|
||||
all_logits, _, all_hidden = model(last, past=perturbed_past)
|
||||
hidden = all_hidden[-1]
|
||||
new_accumulated_hidden = accumulated_hidden + torch.sum(
|
||||
hidden,
|
||||
dim=1
|
||||
).detach()
|
||||
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
|
||||
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
||||
logits = all_logits[:, -1, :]
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
@@ -210,20 +187,13 @@ def perturb_past(
|
||||
wte = model.resize_token_embeddings()
|
||||
for _ in range(horizon_length):
|
||||
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
||||
_, curr_unpert_past, curr_all_hidden = model(
|
||||
past=curr_unpert_past,
|
||||
inputs_embeds=inputs_embeds
|
||||
)
|
||||
_, curr_unpert_past, curr_all_hidden = model(past=curr_unpert_past, inputs_embeds=inputs_embeds)
|
||||
curr_hidden = curr_all_hidden[-1]
|
||||
new_accumulated_hidden = new_accumulated_hidden + torch.sum(
|
||||
curr_hidden, dim=1)
|
||||
new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)
|
||||
|
||||
prediction = classifier(new_accumulated_hidden /
|
||||
(curr_length + 1 + horizon_length))
|
||||
prediction = classifier(new_accumulated_hidden / (curr_length + 1 + horizon_length))
|
||||
|
||||
label = torch.tensor(prediction.shape[0] * [class_label],
|
||||
device=device,
|
||||
dtype=torch.long)
|
||||
label = torch.tensor(prediction.shape[0] * [class_label], device=device, dtype=torch.long)
|
||||
discrim_loss = ce_loss(prediction, label)
|
||||
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
||||
loss += discrim_loss
|
||||
@@ -232,21 +202,15 @@ def perturb_past(
|
||||
kl_loss = 0.0
|
||||
if kl_scale > 0.0:
|
||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
unpert_probs = (
|
||||
unpert_probs + SMALL_CONST *
|
||||
(unpert_probs <= SMALL_CONST).float().to(device).detach()
|
||||
)
|
||||
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
|
||||
device).detach()
|
||||
unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
|
||||
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
|
||||
corrected_probs = probs + correction.detach()
|
||||
kl_loss = kl_scale * (
|
||||
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
|
||||
)
|
||||
print(' kl_loss', kl_loss.data.cpu().numpy())
|
||||
kl_loss = kl_scale * ((corrected_probs * (corrected_probs / unpert_probs).log()).sum())
|
||||
print(" kl_loss", kl_loss.data.cpu().numpy())
|
||||
loss += kl_loss
|
||||
|
||||
loss_per_iter.append(loss.data.cpu().numpy())
|
||||
print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
|
||||
print(" pplm_loss", (loss - kl_loss).data.cpu().numpy())
|
||||
|
||||
# compute gradients
|
||||
loss.backward()
|
||||
@@ -259,15 +223,12 @@ def perturb_past(
|
||||
]
|
||||
else:
|
||||
grad_norms = [
|
||||
(torch.norm(p_.grad * window_mask) + SMALL_CONST)
|
||||
for index, p_ in enumerate(curr_perturbation)
|
||||
(torch.norm(p_.grad * window_mask) + SMALL_CONST) for index, p_ in enumerate(curr_perturbation)
|
||||
]
|
||||
|
||||
# normalize gradients
|
||||
grad = [
|
||||
-stepsize *
|
||||
(p_.grad * window_mask / grad_norms[
|
||||
index] ** gamma).data.cpu().numpy()
|
||||
-stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
|
||||
for index, p_ in enumerate(curr_perturbation)
|
||||
]
|
||||
|
||||
@@ -285,36 +246,27 @@ def perturb_past(
|
||||
past = new_past
|
||||
|
||||
# apply the accumulated perturbations to the past
|
||||
grad_accumulator = [
|
||||
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
|
||||
for p_ in grad_accumulator
|
||||
]
|
||||
grad_accumulator = [to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator]
|
||||
pert_past = list(map(add, past, grad_accumulator))
|
||||
|
||||
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
||||
|
||||
|
||||
def get_classifier(
|
||||
name: Optional[str], class_label: Union[str, int],
|
||||
device: str
|
||||
name: Optional[str], class_label: Union[str, int], device: str
|
||||
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
|
||||
if name is None:
|
||||
return None, None
|
||||
|
||||
params = DISCRIMINATOR_MODELS_PARAMS[name]
|
||||
classifier = ClassificationHead(
|
||||
class_size=params['class_size'],
|
||||
embed_size=params['embed_size']
|
||||
).to(device)
|
||||
classifier = ClassificationHead(class_size=params["class_size"], embed_size=params["embed_size"]).to(device)
|
||||
if "url" in params:
|
||||
resolved_archive_file = cached_path(params["url"])
|
||||
elif "path" in params:
|
||||
resolved_archive_file = params["path"]
|
||||
else:
|
||||
raise ValueError("Either url or path have to be specified "
|
||||
"in the discriminator model parameters")
|
||||
classifier.load_state_dict(
|
||||
torch.load(resolved_archive_file, map_location=device))
|
||||
raise ValueError("Either url or path have to be specified " "in the discriminator model parameters")
|
||||
classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
|
||||
classifier.eval()
|
||||
|
||||
if isinstance(class_label, str):
|
||||
@@ -341,8 +293,7 @@ def get_classifier(
|
||||
return classifier, label_id
|
||||
|
||||
|
||||
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
|
||||
List[List[List[int]]]:
|
||||
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> List[List[List[int]]]:
|
||||
bow_indices = []
|
||||
for id_or_path in bag_of_words_ids_or_paths:
|
||||
if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
|
||||
@@ -351,13 +302,11 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) ->
|
||||
filepath = id_or_path
|
||||
with open(filepath, "r") as f:
|
||||
words = f.read().strip().split("\n")
|
||||
bow_indices.append(
|
||||
[tokenizer.encode(word.strip(), add_prefix_space=True) for word in
|
||||
words])
|
||||
bow_indices.append([tokenizer.encode(word.strip(), add_prefix_space=True) for word in words])
|
||||
return bow_indices
|
||||
|
||||
|
||||
def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
|
||||
def build_bows_one_hot_vectors(bow_indices, tokenizer, device="cuda"):
|
||||
if bow_indices is None:
|
||||
return None
|
||||
|
||||
@@ -373,39 +322,34 @@ def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
|
||||
|
||||
|
||||
def full_text_generation(
|
||||
model,
|
||||
tokenizer,
|
||||
context=None,
|
||||
num_samples=1,
|
||||
device="cuda",
|
||||
bag_of_words=None,
|
||||
discrim=None,
|
||||
class_label=None,
|
||||
length=100,
|
||||
stepsize=0.02,
|
||||
temperature=1.0,
|
||||
top_k=10,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
**kwargs
|
||||
model,
|
||||
tokenizer,
|
||||
context=None,
|
||||
num_samples=1,
|
||||
device="cuda",
|
||||
bag_of_words=None,
|
||||
discrim=None,
|
||||
class_label=None,
|
||||
length=100,
|
||||
stepsize=0.02,
|
||||
temperature=1.0,
|
||||
top_k=10,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
**kwargs
|
||||
):
|
||||
classifier, class_id = get_classifier(
|
||||
discrim,
|
||||
class_label,
|
||||
device
|
||||
)
|
||||
classifier, class_id = get_classifier(discrim, class_label, device)
|
||||
|
||||
bow_indices = []
|
||||
if bag_of_words:
|
||||
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
|
||||
tokenizer)
|
||||
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
|
||||
|
||||
if bag_of_words and classifier:
|
||||
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
|
||||
@@ -423,15 +367,9 @@ def full_text_generation(
|
||||
raise Exception("Specify either a bag of words or a discriminator")
|
||||
|
||||
unpert_gen_tok_text, _, _ = generate_text_pplm(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
context=context,
|
||||
device=device,
|
||||
length=length,
|
||||
sample=sample,
|
||||
perturb=False
|
||||
model=model, tokenizer=tokenizer, context=context, device=device, length=length, sample=sample, perturb=False
|
||||
)
|
||||
if device == 'cuda':
|
||||
if device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
pert_gen_tok_texts = []
|
||||
@@ -468,36 +406,36 @@ def full_text_generation(
|
||||
discrim_losses.append(discrim_loss.data.cpu().numpy())
|
||||
losses_in_time.append(loss_in_time)
|
||||
|
||||
if device == 'cuda':
|
||||
if device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
||||
|
||||
|
||||
def generate_text_pplm(
|
||||
model,
|
||||
tokenizer,
|
||||
context=None,
|
||||
past=None,
|
||||
device="cuda",
|
||||
perturb=True,
|
||||
bow_indices=None,
|
||||
classifier=None,
|
||||
class_label=None,
|
||||
loss_type=0,
|
||||
length=100,
|
||||
stepsize=0.02,
|
||||
temperature=1.0,
|
||||
top_k=10,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
model,
|
||||
tokenizer,
|
||||
context=None,
|
||||
past=None,
|
||||
device="cuda",
|
||||
perturb=True,
|
||||
bow_indices=None,
|
||||
classifier=None,
|
||||
class_label=None,
|
||||
loss_type=0,
|
||||
length=100,
|
||||
stepsize=0.02,
|
||||
temperature=1.0,
|
||||
top_k=10,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
):
|
||||
output_so_far = None
|
||||
if context:
|
||||
@@ -507,8 +445,7 @@ def generate_text_pplm(
|
||||
output_so_far = context_t
|
||||
|
||||
# collect one hot vectors for bags of words
|
||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
|
||||
device)
|
||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, device)
|
||||
|
||||
grad_norms = None
|
||||
last = None
|
||||
@@ -575,13 +512,9 @@ def generate_text_pplm(
|
||||
if classifier is not None:
|
||||
ce_loss = torch.nn.CrossEntropyLoss()
|
||||
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
||||
label = torch.tensor([class_label], device=device,
|
||||
dtype=torch.long)
|
||||
label = torch.tensor([class_label], device=device, dtype=torch.long)
|
||||
unpert_discrim_loss = ce_loss(prediction, label)
|
||||
print(
|
||||
"unperturbed discrim loss",
|
||||
unpert_discrim_loss.data.cpu().numpy()
|
||||
)
|
||||
print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
|
||||
else:
|
||||
unpert_discrim_loss = 0
|
||||
|
||||
@@ -590,10 +523,8 @@ def generate_text_pplm(
|
||||
|
||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
|
||||
pert_probs = ((pert_probs ** gm_scale) * (
|
||||
unpert_probs ** (1 - gm_scale))) # + SMALL_CONST
|
||||
pert_probs = top_k_filter(pert_probs, k=top_k,
|
||||
probs=True) # + SMALL_CONST
|
||||
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
|
||||
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
|
||||
|
||||
# rescale
|
||||
if torch.sum(pert_probs) <= 1:
|
||||
@@ -611,10 +542,7 @@ def generate_text_pplm(
|
||||
_, last = torch.topk(pert_probs, k=1, dim=-1)
|
||||
|
||||
# update context/output_so_far appending the new token
|
||||
output_so_far = (
|
||||
last if output_so_far is None
|
||||
else torch.cat((output_so_far, last), dim=1)
|
||||
)
|
||||
output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1)
|
||||
|
||||
print(tokenizer.decode(output_so_far.tolist()[0]))
|
||||
|
||||
@@ -623,44 +551,42 @@ def generate_text_pplm(
|
||||
|
||||
def set_generic_model_params(discrim_weights, discrim_meta):
|
||||
if discrim_weights is None:
|
||||
raise ValueError('When using a generic discriminator, '
|
||||
'discrim_weights need to be specified')
|
||||
raise ValueError("When using a generic discriminator, " "discrim_weights need to be specified")
|
||||
if discrim_meta is None:
|
||||
raise ValueError('When using a generic discriminator, '
|
||||
'discrim_meta need to be specified')
|
||||
raise ValueError("When using a generic discriminator, " "discrim_meta need to be specified")
|
||||
|
||||
with open(discrim_meta, 'r') as discrim_meta_file:
|
||||
with open(discrim_meta, "r") as discrim_meta_file:
|
||||
meta = json.load(discrim_meta_file)
|
||||
meta['path'] = discrim_weights
|
||||
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
|
||||
meta["path"] = discrim_weights
|
||||
DISCRIMINATOR_MODELS_PARAMS["generic"] = meta
|
||||
|
||||
|
||||
def run_pplm_example(
|
||||
pretrained_model="gpt2-medium",
|
||||
cond_text="",
|
||||
uncond=False,
|
||||
num_samples=1,
|
||||
bag_of_words=None,
|
||||
discrim=None,
|
||||
discrim_weights=None,
|
||||
discrim_meta=None,
|
||||
class_label=-1,
|
||||
length=100,
|
||||
stepsize=0.02,
|
||||
temperature=1.0,
|
||||
top_k=10,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
seed=0,
|
||||
no_cuda=False,
|
||||
colorama=False
|
||||
pretrained_model="gpt2-medium",
|
||||
cond_text="",
|
||||
uncond=False,
|
||||
num_samples=1,
|
||||
bag_of_words=None,
|
||||
discrim=None,
|
||||
discrim_weights=None,
|
||||
discrim_meta=None,
|
||||
class_label=-1,
|
||||
length=100,
|
||||
stepsize=0.02,
|
||||
temperature=1.0,
|
||||
top_k=10,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
seed=0,
|
||||
no_cuda=False,
|
||||
colorama=False,
|
||||
):
|
||||
# set Random seed
|
||||
torch.manual_seed(seed)
|
||||
@@ -669,21 +595,15 @@ def run_pplm_example(
|
||||
# set the device
|
||||
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||
|
||||
if discrim == 'generic':
|
||||
if discrim == "generic":
|
||||
set_generic_model_params(discrim_weights, discrim_meta)
|
||||
|
||||
if discrim is not None:
|
||||
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
|
||||
"pretrained_model"
|
||||
]
|
||||
print("discrim = {}, pretrained_model set "
|
||||
"to discriminator's = {}".format(discrim, pretrained_model))
|
||||
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
|
||||
print("discrim = {}, pretrained_model set " "to discriminator's = {}".format(discrim, pretrained_model))
|
||||
|
||||
# load pretrained model
|
||||
model = GPT2LMHeadModel.from_pretrained(
|
||||
pretrained_model,
|
||||
output_hidden_states=True
|
||||
)
|
||||
model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
@@ -696,9 +616,7 @@ def run_pplm_example(
|
||||
|
||||
# figure out conditioning text
|
||||
if uncond:
|
||||
tokenized_cond_text = tokenizer.encode(
|
||||
[tokenizer.bos_token]
|
||||
)
|
||||
tokenized_cond_text = tokenizer.encode([tokenizer.bos_token])
|
||||
else:
|
||||
raw_text = cond_text
|
||||
while not raw_text:
|
||||
@@ -750,8 +668,7 @@ def run_pplm_example(
|
||||
|
||||
bow_word_ids = set()
|
||||
if bag_of_words and colorama:
|
||||
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
|
||||
tokenizer)
|
||||
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
|
||||
for single_bow_list in bow_indices:
|
||||
# filtering all words in the list composed of more than 1 token
|
||||
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
|
||||
@@ -765,13 +682,11 @@ def run_pplm_example(
|
||||
if colorama:
|
||||
import colorama
|
||||
|
||||
pert_gen_text = ''
|
||||
pert_gen_text = ""
|
||||
for word_id in pert_gen_tok_text.tolist()[0]:
|
||||
if word_id in bow_word_ids:
|
||||
pert_gen_text += '{}{}{}'.format(
|
||||
colorama.Fore.RED,
|
||||
tokenizer.decode([word_id]),
|
||||
colorama.Style.RESET_ALL
|
||||
pert_gen_text += "{}{}{}".format(
|
||||
colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL
|
||||
)
|
||||
else:
|
||||
pert_gen_text += tokenizer.decode([word_id])
|
||||
@@ -785,14 +700,12 @@ def run_pplm_example(
|
||||
pass
|
||||
|
||||
# keep the prefix, perturbed seq, original seq for each index
|
||||
generated_texts.append(
|
||||
(tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
|
||||
)
|
||||
generated_texts.append((tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text))
|
||||
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--pretrained_model",
|
||||
@@ -801,19 +714,10 @@ if __name__ == '__main__':
|
||||
default="gpt2-medium",
|
||||
help="pretrained model name or path to local checkpoint",
|
||||
)
|
||||
parser.add_argument("--cond_text", type=str, default="The lake", help="Prefix texts to condition on")
|
||||
parser.add_argument("--uncond", action="store_true", help="Generate from end-of-text as prefix")
|
||||
parser.add_argument(
|
||||
"--cond_text", type=str, default="The lake",
|
||||
help="Prefix texts to condition on"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--uncond", action="store_true",
|
||||
help="Generate from end-of-text as prefix"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_samples",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of samples to generate from the modified latents",
|
||||
"--num_samples", type=int, default=1, help="Number of samples to generate from the modified latents",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bag_of_words",
|
||||
@@ -821,8 +725,8 @@ if __name__ == '__main__':
|
||||
type=str,
|
||||
default=None,
|
||||
help="Bags of words used for PPLM-BoW. "
|
||||
"Either a BOW id (see list in code) or a filepath. "
|
||||
"Multiple BoWs separated by ;",
|
||||
"Either a BOW id (see list in code) or a filepath. "
|
||||
"Multiple BoWs separated by ;",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrim",
|
||||
@@ -832,48 +736,36 @@ if __name__ == '__main__':
|
||||
choices=("clickbait", "sentiment", "toxicity", "generic"),
|
||||
help="Discriminator to use",
|
||||
)
|
||||
parser.add_argument('--discrim_weights', type=str, default=None,
|
||||
help='Weights for the generic discriminator')
|
||||
parser.add_argument('--discrim_meta', type=str, default=None,
|
||||
help='Meta information for the generic discriminator')
|
||||
parser.add_argument("--discrim_weights", type=str, default=None, help="Weights for the generic discriminator")
|
||||
parser.add_argument(
|
||||
"--class_label",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Class label used for the discriminator",
|
||||
"--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_label", type=int, default=-1, help="Class label used for the discriminator",
|
||||
)
|
||||
parser.add_argument("--length", type=int, default=100)
|
||||
parser.add_argument("--stepsize", type=float, default=0.02)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--top_k", type=int, default=10)
|
||||
parser.add_argument(
|
||||
"--sample", action="store_true",
|
||||
help="Generate from end-of-text as prefix"
|
||||
)
|
||||
parser.add_argument("--sample", action="store_true", help="Generate from end-of-text as prefix")
|
||||
parser.add_argument("--num_iterations", type=int, default=3)
|
||||
parser.add_argument("--grad_length", type=int, default=10000)
|
||||
parser.add_argument(
|
||||
"--window_length",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Length of past which is being optimized; "
|
||||
"0 corresponds to infinite window length",
|
||||
help="Length of past which is being optimized; " "0 corresponds to infinite window length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--horizon_length",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Length of future to optimize over",
|
||||
"--horizon_length", type=int, default=1, help="Length of future to optimize over",
|
||||
)
|
||||
parser.add_argument("--decay", action="store_true",
|
||||
help="whether to decay or not")
|
||||
parser.add_argument("--decay", action="store_true", help="whether to decay or not")
|
||||
parser.add_argument("--gamma", type=float, default=1.5)
|
||||
parser.add_argument("--gm_scale", type=float, default=0.9)
|
||||
parser.add_argument("--kl_scale", type=float, default=0.01)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
|
||||
parser.add_argument("--colorama", action="store_true",
|
||||
help="colors keywords")
|
||||
parser.add_argument("--colorama", action="store_true", help="colors keywords")
|
||||
|
||||
args = parser.parse_args()
|
||||
run_pplm_example(**vars(args))
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
#! /usr/bin/env python3
|
||||
# coding=utf-8
|
||||
|
||||
#Copyright (c) 2019 Uber Technologies, Inc.
|
||||
# Copyright (c) 2019 Uber Technologies, Inc.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
#http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
@@ -42,26 +42,15 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
|
||||
max_length_seq = 100
|
||||
|
||||
|
||||
|
||||
|
||||
class Discriminator(torch.nn.Module):
|
||||
"""Transformer encoder followed by a Classification Head"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
class_size,
|
||||
pretrained_model="gpt2-medium",
|
||||
cached_mode=False,
|
||||
device='cpu'
|
||||
):
|
||||
def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
|
||||
super(Discriminator, self).__init__()
|
||||
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
||||
self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
|
||||
self.embed_size = self.encoder.transformer.config.hidden_size
|
||||
self.classifier_head = ClassificationHead(
|
||||
class_size=class_size,
|
||||
embed_size=self.embed_size
|
||||
)
|
||||
self.classifier_head = ClassificationHead(class_size=class_size, embed_size=self.embed_size)
|
||||
self.cached_mode = cached_mode
|
||||
self.device = device
|
||||
|
||||
@@ -74,14 +63,10 @@ class Discriminator(torch.nn.Module):
|
||||
self.classifier_head.train()
|
||||
|
||||
def avg_representation(self, x):
|
||||
mask = x.ne(0).unsqueeze(2).repeat(
|
||||
1, 1, self.embed_size
|
||||
).float().to(self.device).detach()
|
||||
mask = x.ne(0).unsqueeze(2).repeat(1, 1, self.embed_size).float().to(self.device).detach()
|
||||
hidden, _ = self.encoder.transformer(x)
|
||||
masked_hidden = hidden * mask
|
||||
avg_hidden = torch.sum(masked_hidden, dim=1) / (
|
||||
torch.sum(mask, dim=1).detach() + EPSILON
|
||||
)
|
||||
avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(mask, dim=1).detach() + EPSILON)
|
||||
return avg_hidden
|
||||
|
||||
def forward(self, x):
|
||||
@@ -117,10 +102,7 @@ def collate_fn(data):
|
||||
def pad_sequences(sequences):
|
||||
lengths = [len(seq) for seq in sequences]
|
||||
|
||||
padded_sequences = torch.zeros(
|
||||
len(sequences),
|
||||
max(lengths)
|
||||
).long() # padding value = 0
|
||||
padded_sequences = torch.zeros(len(sequences), max(lengths)).long() # padding value = 0
|
||||
|
||||
for i, seq in enumerate(sequences):
|
||||
end = lengths[i]
|
||||
@@ -149,8 +131,7 @@ def cached_collate_fn(data):
|
||||
return x_batch, y_batch
|
||||
|
||||
|
||||
def train_epoch(data_loader, discriminator, optimizer,
|
||||
epoch=0, log_interval=10, device='cpu'):
|
||||
def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10, device="cpu"):
|
||||
samples_so_far = 0
|
||||
discriminator.train_custom()
|
||||
for batch_idx, (input_t, target_t) in enumerate(data_loader):
|
||||
@@ -169,13 +150,15 @@ def train_epoch(data_loader, discriminator, optimizer,
|
||||
print(
|
||||
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
|
||||
epoch + 1,
|
||||
samples_so_far, len(data_loader.dataset),
|
||||
100 * samples_so_far / len(data_loader.dataset), loss.item()
|
||||
samples_so_far,
|
||||
len(data_loader.dataset),
|
||||
100 * samples_so_far / len(data_loader.dataset),
|
||||
loss.item(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def evaluate_performance(data_loader, discriminator, device='cpu'):
|
||||
def evaluate_performance(data_loader, discriminator, device="cpu"):
|
||||
discriminator.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
@@ -194,13 +177,12 @@ def evaluate_performance(data_loader, discriminator, device='cpu'):
|
||||
print(
|
||||
"Performance on test set: "
|
||||
"Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
|
||||
test_loss, correct, len(data_loader.dataset),
|
||||
100. * correct / len(data_loader.dataset)
|
||||
test_loss, correct, len(data_loader.dataset), 100.0 * correct / len(data_loader.dataset)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def predict(input_sentence, model, classes, cached=False, device='cpu'):
|
||||
def predict(input_sentence, model, classes, cached=False, device="cpu"):
|
||||
input_t = model.tokenizer.encode(input_sentence)
|
||||
input_t = torch.tensor([input_t], dtype=torch.long, device=device)
|
||||
if cached:
|
||||
@@ -208,17 +190,14 @@ def predict(input_sentence, model, classes, cached=False, device='cpu'):
|
||||
|
||||
log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
|
||||
print("Input sentence:", input_sentence)
|
||||
print("Predictions:", ", ".join(
|
||||
"{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
|
||||
zip(classes, log_probs)
|
||||
))
|
||||
print(
|
||||
"Predictions:",
|
||||
", ".join("{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in zip(classes, log_probs)),
|
||||
)
|
||||
|
||||
|
||||
def get_cached_data_loader(dataset, batch_size, discriminator,
|
||||
shuffle=False, device='cpu'):
|
||||
data_loader = torch.utils.data.DataLoader(dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=collate_fn)
|
||||
def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False, device="cpu"):
|
||||
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate_fn)
|
||||
|
||||
xs = []
|
||||
ys = []
|
||||
@@ -231,50 +210,44 @@ def get_cached_data_loader(dataset, batch_size, discriminator,
|
||||
ys += y.cpu().numpy().tolist()
|
||||
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset=Dataset(xs, ys),
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
collate_fn=cached_collate_fn)
|
||||
dataset=Dataset(xs, ys), batch_size=batch_size, shuffle=shuffle, collate_fn=cached_collate_fn
|
||||
)
|
||||
|
||||
return data_loader
|
||||
|
||||
|
||||
def train_discriminator(
|
||||
dataset, dataset_fp=None, pretrained_model="gpt2-medium",
|
||||
epochs=10, batch_size=64, log_interval=10,
|
||||
save_model=False, cached=False, no_cuda=False):
|
||||
dataset,
|
||||
dataset_fp=None,
|
||||
pretrained_model="gpt2-medium",
|
||||
epochs=10,
|
||||
batch_size=64,
|
||||
log_interval=10,
|
||||
save_model=False,
|
||||
cached=False,
|
||||
no_cuda=False,
|
||||
):
|
||||
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||
|
||||
print("Preprocessing {} dataset...".format(dataset))
|
||||
start = time.time()
|
||||
|
||||
if dataset == "SST":
|
||||
idx2class = ["positive", "negative", "very positive", "very negative",
|
||||
"neutral"]
|
||||
idx2class = ["positive", "negative", "very positive", "very negative", "neutral"]
|
||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class),
|
||||
pretrained_model=pretrained_model,
|
||||
cached_mode=cached,
|
||||
device=device
|
||||
class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
|
||||
).to(device)
|
||||
|
||||
text = torchtext_data.Field()
|
||||
label = torchtext_data.Field(sequential=False)
|
||||
train_data, val_data, test_data = datasets.SST.splits(
|
||||
text,
|
||||
label,
|
||||
fine_grained=True,
|
||||
train_subtrees=True,
|
||||
)
|
||||
train_data, val_data, test_data = datasets.SST.splits(text, label, fine_grained=True, train_subtrees=True,)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
for i in trange(len(train_data), ascii=True):
|
||||
seq = TreebankWordDetokenizer().detokenize(
|
||||
vars(train_data[i])["text"]
|
||||
)
|
||||
seq = TreebankWordDetokenizer().detokenize(vars(train_data[i])["text"])
|
||||
seq = discriminator.tokenizer.encode(seq)
|
||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||
x.append(seq)
|
||||
@@ -284,9 +257,7 @@ def train_discriminator(
|
||||
test_x = []
|
||||
test_y = []
|
||||
for i in trange(len(test_data), ascii=True):
|
||||
seq = TreebankWordDetokenizer().detokenize(
|
||||
vars(test_data[i])["text"]
|
||||
)
|
||||
seq = TreebankWordDetokenizer().detokenize(vars(test_data[i])["text"])
|
||||
seq = discriminator.tokenizer.encode(seq)
|
||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||
test_x.append(seq)
|
||||
@@ -306,10 +277,7 @@ def train_discriminator(
|
||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class),
|
||||
pretrained_model=pretrained_model,
|
||||
cached_mode=cached,
|
||||
device=device
|
||||
class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
|
||||
).to(device)
|
||||
|
||||
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
||||
@@ -318,9 +286,7 @@ def train_discriminator(
|
||||
try:
|
||||
data.append(eval(line))
|
||||
except:
|
||||
print("Error evaluating line {}: {}".format(
|
||||
i, line
|
||||
))
|
||||
print("Error evaluating line {}: {}".format(i, line))
|
||||
continue
|
||||
x = []
|
||||
y = []
|
||||
@@ -331,27 +297,20 @@ def train_discriminator(
|
||||
seq = discriminator.tokenizer.encode(d["text"])
|
||||
|
||||
if len(seq) < max_length_seq:
|
||||
seq = torch.tensor(
|
||||
[50256] + seq, device=device, dtype=torch.long
|
||||
)
|
||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||
else:
|
||||
print("Line {} is longer than maximum length {}".format(
|
||||
i, max_length_seq
|
||||
))
|
||||
print("Line {} is longer than maximum length {}".format(i, max_length_seq))
|
||||
continue
|
||||
x.append(seq)
|
||||
y.append(d["label"])
|
||||
except:
|
||||
print("Error evaluating / tokenizing"
|
||||
" line {}, skipping it".format(i))
|
||||
print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
|
||||
pass
|
||||
|
||||
full_dataset = Dataset(x, y)
|
||||
train_size = int(0.9 * len(full_dataset))
|
||||
test_size = len(full_dataset) - train_size
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(
|
||||
full_dataset, [train_size, test_size]
|
||||
)
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
|
||||
|
||||
discriminator_meta = {
|
||||
"class_size": len(idx2class),
|
||||
@@ -366,10 +325,7 @@ def train_discriminator(
|
||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class),
|
||||
pretrained_model=pretrained_model,
|
||||
cached_mode=cached,
|
||||
device=device
|
||||
class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
|
||||
).to(device)
|
||||
|
||||
x = []
|
||||
@@ -381,27 +337,20 @@ def train_discriminator(
|
||||
seq = discriminator.tokenizer.encode(d["text"])
|
||||
|
||||
if len(seq) < max_length_seq:
|
||||
seq = torch.tensor(
|
||||
[50256] + seq, device=device, dtype=torch.long
|
||||
)
|
||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||
else:
|
||||
print("Line {} is longer than maximum length {}".format(
|
||||
i, max_length_seq
|
||||
))
|
||||
print("Line {} is longer than maximum length {}".format(i, max_length_seq))
|
||||
continue
|
||||
x.append(seq)
|
||||
y.append(int(np.sum(d["label"]) > 0))
|
||||
except:
|
||||
print("Error evaluating / tokenizing"
|
||||
" line {}, skipping it".format(i))
|
||||
print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
|
||||
pass
|
||||
|
||||
full_dataset = Dataset(x, y)
|
||||
train_size = int(0.9 * len(full_dataset))
|
||||
test_size = len(full_dataset) - train_size
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(
|
||||
full_dataset, [train_size, test_size]
|
||||
)
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
|
||||
|
||||
discriminator_meta = {
|
||||
"class_size": len(idx2class),
|
||||
@@ -416,8 +365,7 @@ def train_discriminator(
|
||||
# class \t text
|
||||
|
||||
if dataset_fp is None:
|
||||
raise ValueError("When generic dataset is selected, "
|
||||
"dataset_fp needs to be specified aswell.")
|
||||
raise ValueError("When generic dataset is selected, " "dataset_fp needs to be specified aswell.")
|
||||
|
||||
classes = set()
|
||||
with open(dataset_fp) as f:
|
||||
@@ -430,10 +378,7 @@ def train_discriminator(
|
||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class),
|
||||
pretrained_model=pretrained_model,
|
||||
cached_mode=cached,
|
||||
device=device
|
||||
class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
|
||||
).to(device)
|
||||
|
||||
x = []
|
||||
@@ -447,18 +392,11 @@ def train_discriminator(
|
||||
|
||||
try:
|
||||
seq = discriminator.tokenizer.encode(text)
|
||||
if (len(seq) < max_length_seq):
|
||||
seq = torch.tensor(
|
||||
[50256] + seq,
|
||||
device=device,
|
||||
dtype=torch.long
|
||||
)
|
||||
if len(seq) < max_length_seq:
|
||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||
|
||||
else:
|
||||
print(
|
||||
"Line {} is longer than maximum length {}".format(
|
||||
i, max_length_seq
|
||||
))
|
||||
print("Line {} is longer than maximum length {}".format(i, max_length_seq))
|
||||
continue
|
||||
|
||||
x.append(seq)
|
||||
@@ -471,10 +409,7 @@ def train_discriminator(
|
||||
full_dataset = Dataset(x, y)
|
||||
train_size = int(0.9 * len(full_dataset))
|
||||
test_size = len(full_dataset) - train_size
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(
|
||||
full_dataset,
|
||||
[train_size, test_size]
|
||||
)
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
|
||||
|
||||
discriminator_meta = {
|
||||
"class_size": len(idx2class),
|
||||
@@ -485,9 +420,7 @@ def train_discriminator(
|
||||
}
|
||||
|
||||
end = time.time()
|
||||
print("Preprocessed {} data points".format(
|
||||
len(train_dataset) + len(test_dataset))
|
||||
)
|
||||
print("Preprocessed {} data points".format(len(train_dataset) + len(test_dataset)))
|
||||
print("Data preprocessing took: {:.3f}s".format(end - start))
|
||||
|
||||
if cached:
|
||||
@@ -495,30 +428,21 @@ def train_discriminator(
|
||||
|
||||
start = time.time()
|
||||
|
||||
train_loader = get_cached_data_loader(
|
||||
train_dataset, batch_size, discriminator,
|
||||
shuffle=True, device=device
|
||||
)
|
||||
train_loader = get_cached_data_loader(train_dataset, batch_size, discriminator, shuffle=True, device=device)
|
||||
|
||||
test_loader = get_cached_data_loader(
|
||||
test_dataset, batch_size, discriminator, device=device
|
||||
)
|
||||
test_loader = get_cached_data_loader(test_dataset, batch_size, discriminator, device=device)
|
||||
|
||||
end = time.time()
|
||||
print("Building representation cache took: {:.3f}s".format(end - start))
|
||||
|
||||
else:
|
||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn)
|
||||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=collate_fn)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
|
||||
)
|
||||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, collate_fn=collate_fn)
|
||||
|
||||
if save_model:
|
||||
with open("{}_classifier_head_meta.json".format(dataset),
|
||||
"w") as meta_file:
|
||||
with open("{}_classifier_head_meta.json".format(dataset), "w") as meta_file:
|
||||
json.dump(discriminator_meta, meta_file)
|
||||
|
||||
optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
|
||||
@@ -533,56 +457,61 @@ def train_discriminator(
|
||||
optimizer=optimizer,
|
||||
epoch=epoch,
|
||||
log_interval=log_interval,
|
||||
device=device
|
||||
)
|
||||
evaluate_performance(
|
||||
data_loader=test_loader,
|
||||
discriminator=discriminator,
|
||||
device=device
|
||||
device=device,
|
||||
)
|
||||
evaluate_performance(data_loader=test_loader, discriminator=discriminator, device=device)
|
||||
|
||||
end = time.time()
|
||||
print("Epoch took: {:.3f}s".format(end - start))
|
||||
|
||||
print("\nExample prediction")
|
||||
predict(example_sentence, discriminator, idx2class,
|
||||
cached=cached, device=device)
|
||||
predict(example_sentence, discriminator, idx2class, cached=cached, device=device)
|
||||
|
||||
if save_model:
|
||||
# torch.save(discriminator.state_dict(),
|
||||
# "{}_discriminator_{}.pt".format(
|
||||
# args.dataset, epoch + 1
|
||||
# ))
|
||||
torch.save(discriminator.get_classifier().state_dict(),
|
||||
"{}_classifier_head_epoch_{}.pt".format(dataset,
|
||||
epoch + 1))
|
||||
torch.save(
|
||||
discriminator.get_classifier().state_dict(),
|
||||
"{}_classifier_head_epoch_{}.pt".format(dataset, epoch + 1),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train a discriminator on top of GPT-2 representations")
|
||||
parser.add_argument("--dataset", type=str, default="SST",
|
||||
choices=("SST", "clickbait", "toxic", "generic"),
|
||||
help="dataset to train the discriminator on."
|
||||
"In case of generic, the dataset is expected"
|
||||
"to be a TSBV file with structure: class \\t text")
|
||||
parser.add_argument("--dataset_fp", type=str, default="",
|
||||
help="File path of the dataset to use. "
|
||||
"Needed only in case of generic datadset")
|
||||
parser.add_argument("--pretrained_model", type=str, default="gpt2-medium",
|
||||
help="Pretrained model to use as encoder")
|
||||
parser.add_argument("--epochs", type=int, default=10, metavar="N",
|
||||
help="Number of training epochs")
|
||||
parser.add_argument("--batch_size", type=int, default=64, metavar="N",
|
||||
help="input batch size for training (default: 64)")
|
||||
parser.add_argument("--log_interval", type=int, default=10, metavar="N",
|
||||
help="how many batches to wait before logging training status")
|
||||
parser.add_argument("--save_model", action="store_true",
|
||||
help="whether to save the model")
|
||||
parser.add_argument("--cached", action="store_true",
|
||||
help="whether to cache the input representations")
|
||||
parser.add_argument("--no_cuda", action="store_true",
|
||||
help="use to turn off cuda")
|
||||
parser = argparse.ArgumentParser(description="Train a discriminator on top of GPT-2 representations")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="SST",
|
||||
choices=("SST", "clickbait", "toxic", "generic"),
|
||||
help="dataset to train the discriminator on."
|
||||
"In case of generic, the dataset is expected"
|
||||
"to be a TSBV file with structure: class \\t text",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_fp",
|
||||
type=str,
|
||||
default="",
|
||||
help="File path of the dataset to use. " "Needed only in case of generic datadset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model", type=str, default="gpt2-medium", help="Pretrained model to use as encoder"
|
||||
)
|
||||
parser.add_argument("--epochs", type=int, default=10, metavar="N", help="Number of training epochs")
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log_interval",
|
||||
type=int,
|
||||
default=10,
|
||||
metavar="N",
|
||||
help="how many batches to wait before logging training status",
|
||||
)
|
||||
parser.add_argument("--save_model", action="store_true", help="whether to save the model")
|
||||
parser.add_argument("--cached", action="store_true", help="whether to cache the input representations")
|
||||
parser.add_argument("--no_cuda", action="store_true", help="use to turn off cuda")
|
||||
args = parser.parse_args()
|
||||
|
||||
train_discriminator(**(vars(args)))
|
||||
|
||||
@@ -32,10 +32,18 @@ from torch.utils.data import DataLoader, SequentialSampler, TensorDataset, Subse
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers import (WEIGHTS_NAME,
|
||||
BertConfig, BertForSequenceClassification, BertTokenizer,
|
||||
XLMConfig, XLMForSequenceClassification, XLMTokenizer,
|
||||
XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer)
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
BertConfig,
|
||||
BertForSequenceClassification,
|
||||
BertTokenizer,
|
||||
XLMConfig,
|
||||
XLMForSequenceClassification,
|
||||
XLMTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetTokenizer,
|
||||
)
|
||||
|
||||
from run_glue import set_seed, load_and_cache_examples, ALL_MODELS, MODEL_CLASSES
|
||||
|
||||
@@ -63,7 +71,9 @@ def print_2d_tensor(tensor):
|
||||
logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:d}" for x in tensor[row].cpu().data))
|
||||
|
||||
|
||||
def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None):
|
||||
def compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None
|
||||
):
|
||||
""" This method shows how to compute:
|
||||
- head attention entropy
|
||||
- head importance scores according to http://arxiv.org/abs/1905.10650
|
||||
@@ -85,8 +95,14 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
|
||||
input_ids, input_mask, segment_ids, label_ids = batch
|
||||
|
||||
# Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
|
||||
outputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids, head_mask=head_mask)
|
||||
loss, logits, all_attentions = outputs[0], outputs[1], outputs[-1] # Loss and logits are the first, attention the last
|
||||
outputs = model(
|
||||
input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids, head_mask=head_mask
|
||||
)
|
||||
loss, logits, all_attentions = (
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
outputs[-1],
|
||||
) # Loss and logits are the first, attention the last
|
||||
loss.backward() # Backpropagate to populate the gradients in the head mask
|
||||
|
||||
if compute_entropy:
|
||||
@@ -113,15 +129,15 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
|
||||
# Layerwise importance normalization
|
||||
if not args.dont_normalize_importance_by_layer:
|
||||
exponent = 2
|
||||
norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1/exponent)
|
||||
norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
|
||||
head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20
|
||||
|
||||
if not args.dont_normalize_global_importance:
|
||||
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
|
||||
|
||||
# Print/save matrices
|
||||
np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy.detach().cpu().numpy())
|
||||
np.save(os.path.join(args.output_dir, 'head_importance.npy'), head_importance.detach().cpu().numpy())
|
||||
np.save(os.path.join(args.output_dir, "attn_entropy.npy"), attn_entropy.detach().cpu().numpy())
|
||||
np.save(os.path.join(args.output_dir, "head_importance.npy"), head_importance.detach().cpu().numpy())
|
||||
|
||||
logger.info("Attention entropies")
|
||||
print_2d_tensor(attn_entropy)
|
||||
@@ -129,7 +145,9 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
|
||||
print_2d_tensor(head_importance)
|
||||
logger.info("Head ranked by importance scores")
|
||||
head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
|
||||
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel(), device=args.device)
|
||||
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
|
||||
head_importance.numel(), device=args.device
|
||||
)
|
||||
head_ranks = head_ranks.view_as(head_importance)
|
||||
print_2d_tensor(head_ranks)
|
||||
|
||||
@@ -150,9 +168,9 @@ def mask_heads(args, model, eval_dataloader):
|
||||
|
||||
current_score = original_score
|
||||
while current_score >= original_score * args.masking_threshold:
|
||||
head_mask = new_head_mask.clone() # save current head mask
|
||||
head_mask = new_head_mask.clone() # save current head mask
|
||||
# heads from least important to most - keep only not-masked heads
|
||||
head_importance[head_mask == 0.0] = float('Inf')
|
||||
head_importance[head_mask == 0.0] = float("Inf")
|
||||
current_heads_to_mask = head_importance.view(-1).sort()[1]
|
||||
|
||||
if len(current_heads_to_mask) <= num_to_mask:
|
||||
@@ -167,14 +185,21 @@ def mask_heads(args, model, eval_dataloader):
|
||||
print_2d_tensor(new_head_mask)
|
||||
|
||||
# Compute metric and head importance again
|
||||
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask)
|
||||
_, head_importance, preds, labels = compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
|
||||
)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
current_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100)
|
||||
logger.info(
|
||||
"Masking: current score: %f, remaning heads %d (%.1f percents)",
|
||||
current_score,
|
||||
new_head_mask.sum(),
|
||||
new_head_mask.sum() / new_head_mask.numel() * 100,
|
||||
)
|
||||
|
||||
logger.info("Final head mask")
|
||||
print_2d_tensor(head_mask)
|
||||
np.save(os.path.join(args.output_dir, 'head_mask.npy'), head_mask.detach().cpu().numpy())
|
||||
np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())
|
||||
|
||||
return head_mask
|
||||
|
||||
@@ -186,8 +211,9 @@ def prune_heads(args, model, eval_dataloader, head_mask):
|
||||
# Try pruning and test time speedup
|
||||
# Pruning is like masking but we actually remove the masked weights
|
||||
before_time = datetime.now()
|
||||
_, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
|
||||
compute_entropy=False, compute_importance=False, head_mask=head_mask)
|
||||
_, _, preds, labels = compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
|
||||
)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
score_masking = compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
original_time = datetime.now() - before_time
|
||||
@@ -199,73 +225,127 @@ def prune_heads(args, model, eval_dataloader, head_mask):
|
||||
pruned_num_params = sum(p.numel() for p in model.parameters())
|
||||
|
||||
before_time = datetime.now()
|
||||
_, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
|
||||
compute_entropy=False, compute_importance=False, head_mask=None)
|
||||
_, _, preds, labels = compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=None
|
||||
)
|
||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||
score_pruning = compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||
new_time = datetime.now() - before_time
|
||||
|
||||
logger.info("Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)", original_num_params, pruned_num_params, pruned_num_params/original_num_params * 100)
|
||||
logger.info(
|
||||
"Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)",
|
||||
original_num_params,
|
||||
pruned_num_params,
|
||||
pruned_num_params / original_num_params * 100,
|
||||
)
|
||||
logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
|
||||
logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time/new_time * 100)
|
||||
logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time / new_time * 100)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(
|
||||
ALL_MODELS))
|
||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name_or_path")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name_or_path")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument("--data_subset", type=int, default=-1,
|
||||
help="If > 0: limit the data to a subset of data_subset instances.")
|
||||
parser.add_argument("--overwrite_output_dir", action='store_true',
|
||||
help="Whether to overwrite data in output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained config name or path if not the same as model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
|
||||
parser.add_argument("--dont_normalize_importance_by_layer", action='store_true',
|
||||
help="Don't normalize importance score by layers")
|
||||
parser.add_argument("--dont_normalize_global_importance", action='store_true',
|
||||
help="Don't normalize all importance scores between 0 and 1")
|
||||
parser.add_argument(
|
||||
"--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dont_normalize_global_importance",
|
||||
action="store_true",
|
||||
help="Don't normalize all importance scores between 0 and 1",
|
||||
)
|
||||
|
||||
parser.add_argument("--try_masking", action='store_true',
|
||||
help="Whether to try to mask head until a threshold of accuracy.")
|
||||
parser.add_argument("--masking_threshold", default=0.9, type=float,
|
||||
help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).")
|
||||
parser.add_argument("--masking_amount", default=0.1, type=float,
|
||||
help="Amount to heads to masking at each masking step.")
|
||||
parser.add_argument("--metric_name", default="acc", type=str,
|
||||
help="Metric to use for head masking.")
|
||||
parser.add_argument(
|
||||
"--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masking_threshold",
|
||||
default=0.9,
|
||||
type=float,
|
||||
help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step."
|
||||
)
|
||||
parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")
|
||||
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||
"Sequences longer than this will be truncated, sequences shorter padded.")
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||
"Sequences longer than this will be truncated, sequences shorter padded.",
|
||||
)
|
||||
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||
parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
@@ -278,10 +358,10 @@ def main():
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
args.device = torch.device("cuda", args.local_rank)
|
||||
args.n_gpu = 1
|
||||
torch.distributed.init_process_group(backend='nccl') # Initializes the distributed backend
|
||||
torch.distributed.init_process_group(backend="nccl") # Initializes the distributed backend
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))
|
||||
|
||||
# Set seeds
|
||||
@@ -306,17 +386,23 @@ def main():
|
||||
args.model_type = key # take the first match in model types
|
||||
break
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
output_attentions=True,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model = model_class.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
output_attentions=True,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
@@ -324,14 +410,14 @@ def main():
|
||||
# Distributed and parallel training
|
||||
model.to(args.device)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
elif args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Print/save training arguments
|
||||
torch.save(args, os.path.join(args.output_dir, 'run_args.bin'))
|
||||
torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Prepare dataset for the GLUE task
|
||||
@@ -341,11 +427,9 @@ def main():
|
||||
eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
|
||||
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
|
||||
|
||||
|
||||
# Compute head entropy and importance score
|
||||
compute_heads_importance(args, model, eval_dataloader)
|
||||
|
||||
|
||||
# Try head masking (set heads to zero until the score goes under a threshole)
|
||||
# and head pruning (remove masked heads and see the effect on the network)
|
||||
if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
|
||||
@@ -353,5 +437,5 @@ def main():
|
||||
prune_heads(args, model, eval_dataloader, head_mask)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -33,9 +33,7 @@ from transformers import XLMWithLMHeadModel, XLMTokenizer
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -71,6 +69,7 @@ def set_seed(args):
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
#
|
||||
# Functions to prepare models' input
|
||||
#
|
||||
@@ -78,15 +77,11 @@ def set_seed(args):
|
||||
|
||||
def prepare_ctrl_input(args, _, tokenizer, prompt_text):
|
||||
if args.temperature > 0.7:
|
||||
logger.info(
|
||||
"CTRL typically works better with lower temperatures (and lower top_k)."
|
||||
)
|
||||
logger.info("CTRL typically works better with lower temperatures (and lower top_k).")
|
||||
|
||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
|
||||
if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
|
||||
logger.info(
|
||||
"WARNING! You are not starting your generation from a control code so you won't get good results"
|
||||
)
|
||||
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
|
||||
return prompt_text
|
||||
|
||||
|
||||
@@ -102,11 +97,7 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
||||
else:
|
||||
language = None
|
||||
while language not in available_languages:
|
||||
language = input(
|
||||
"Using XLM. Select language in "
|
||||
+ str(list(available_languages))
|
||||
+ " >>> "
|
||||
)
|
||||
language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
|
||||
# kwargs["language"] = tokenizer.lang2id[language]
|
||||
|
||||
# TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
|
||||
@@ -148,17 +139,34 @@ def adjust_length_to_model(length, max_sequence_length):
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
|
||||
parser.add_argument("--prompt", type=str, default="")
|
||||
parser.add_argument("--length", type=int, default=20)
|
||||
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 1.0 has no effect, lower tend toward greedy sampling")
|
||||
parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2")
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
|
||||
)
|
||||
parser.add_argument("--k", type=int, default=0)
|
||||
parser.add_argument("--p", type=float, default=0.9)
|
||||
|
||||
@@ -169,9 +177,7 @@ def main():
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
|
||||
)
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
|
||||
set_seed(args)
|
||||
@@ -181,17 +187,13 @@ def main():
|
||||
args.model_type = args.model_type.lower()
|
||||
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
"the model {} you specified is not supported. You are welcome to add it and open a PR :)"
|
||||
)
|
||||
raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")
|
||||
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||
model = model_class.from_pretrained(args.model_name_or_path)
|
||||
model.to(args.device)
|
||||
|
||||
args.length = adjust_length_to_model(
|
||||
args.length, max_sequence_length=model.config.max_position_embeddings
|
||||
)
|
||||
args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
|
||||
logger.info(args)
|
||||
|
||||
prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
|
||||
@@ -201,7 +203,7 @@ def main():
|
||||
if requires_preprocessing:
|
||||
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
||||
prompt_text = prepare_input(args, model, tokenizer, prompt_text)
|
||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')
|
||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
||||
|
||||
output_sequences = model.generate(
|
||||
input_ids=encoded_prompt,
|
||||
|
||||
@@ -26,8 +26,7 @@ import json
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
try:
|
||||
@@ -37,25 +36,30 @@ except:
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForSequenceClassification, BertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaTokenizer,
|
||||
XLMConfig, XLMForSequenceClassification,
|
||||
XLMTokenizer, XLNetConfig,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForSequenceClassification,
|
||||
DistilBertTokenizer,
|
||||
AlbertConfig,
|
||||
AlbertForSequenceClassification,
|
||||
AlbertTokenizer,
|
||||
XLMRobertaConfig,
|
||||
XLMRobertaForSequenceClassification,
|
||||
XLMRobertaTokenizer,
|
||||
)
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
BertConfig,
|
||||
BertForSequenceClassification,
|
||||
BertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaTokenizer,
|
||||
XLMConfig,
|
||||
XLMForSequenceClassification,
|
||||
XLMTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForSequenceClassification,
|
||||
DistilBertTokenizer,
|
||||
AlbertConfig,
|
||||
AlbertForSequenceClassification,
|
||||
AlbertTokenizer,
|
||||
XLMRobertaConfig,
|
||||
XLMRobertaForSequenceClassification,
|
||||
XLMRobertaTokenizer,
|
||||
)
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
@@ -66,17 +70,22 @@ from transformers import glue_convert_examples_to_features as convert_examples_t
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig,
|
||||
RobertaConfig, DistilBertConfig)), ())
|
||||
ALL_MODELS = sum(
|
||||
(
|
||||
tuple(conf.pretrained_config_archive_map.keys())
|
||||
for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
|
||||
),
|
||||
(),
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
||||
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
|
||||
'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
|
||||
'xlmroberta': (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
|
||||
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
"xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
||||
"xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
|
||||
"albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
|
||||
"xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
|
||||
}
|
||||
|
||||
|
||||
@@ -104,20 +113,27 @@ def train(args, train_dataset, model, tokenizer):
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
@@ -132,17 +148,21 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
@@ -152,7 +172,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
# Check if continuing training from a checkpoint
|
||||
if os.path.exists(args.model_name_or_path):
|
||||
# set global_step to gobal_step of last saved checkpoint from model path
|
||||
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
|
||||
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
|
||||
@@ -163,7 +183,9 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
train_iterator = trange(
|
||||
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||
)
|
||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
@@ -176,16 +198,16 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'labels': batch[3]}
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
@@ -209,36 +231,40 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
logs = {}
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
eval_key = 'eval_{}'.format(key)
|
||||
eval_key = "eval_{}".format(key)
|
||||
logs[eval_key] = value
|
||||
|
||||
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
||||
learning_rate_scalar = scheduler.get_lr()[0]
|
||||
logs['learning_rate'] = learning_rate_scalar
|
||||
logs['loss'] = loss_scalar
|
||||
logs["learning_rate"] = learning_rate_scalar
|
||||
logs["loss"] = loss_scalar
|
||||
logging_loss = tr_loss
|
||||
|
||||
for key, value in logs.items():
|
||||
tb_writer.add_scalar(key, value, global_step)
|
||||
print(json.dumps({**logs, **{'step': global_step}}))
|
||||
print(json.dumps({**logs, **{"step": global_step}}))
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
@@ -257,7 +283,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
|
||||
eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)
|
||||
|
||||
results = {}
|
||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||
@@ -288,11 +314,11 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'labels': batch[3]}
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
@@ -300,10 +326,10 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs['labels'].detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if args.output_mode == "classification":
|
||||
@@ -330,29 +356,36 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
processor = processors[task]()
|
||||
output_mode = output_modes[task]
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
|
||||
'dev' if evaluate else 'train',
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task)))
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta', 'xlmroberta']:
|
||||
if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
features = convert_examples_to_features(examples,
|
||||
tokenizer,
|
||||
label_list=label_list,
|
||||
max_length=args.max_seq_length,
|
||||
output_mode=output_mode,
|
||||
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
|
||||
examples = (
|
||||
processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
examples,
|
||||
tokenizer,
|
||||
label_list=label_list,
|
||||
max_length=args.max_seq_length,
|
||||
output_mode=output_mode,
|
||||
pad_on_left=bool(args.model_type in ["xlnet"]), # pad on the left for xlnet
|
||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
@@ -378,90 +411,149 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||
help="Rul evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="For distributed training: local_rank")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
@@ -473,16 +565,24 @@ def main():
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
@@ -502,17 +602,23 @@ def main():
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model = model_class.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
@@ -521,14 +627,12 @@ def main():
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
@@ -538,36 +642,39 @@ def main():
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
@@ -42,37 +42,55 @@ except:
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
|
||||
BertConfig, BertForMaskedLM, BertTokenizer,
|
||||
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
||||
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
||||
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
|
||||
DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer,
|
||||
CamembertConfig, CamembertForMaskedLM, CamembertTokenizer)
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
get_linear_schedule_with_warmup,
|
||||
BertConfig,
|
||||
BertForMaskedLM,
|
||||
BertTokenizer,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
GPT2Tokenizer,
|
||||
OpenAIGPTConfig,
|
||||
OpenAIGPTLMHeadModel,
|
||||
OpenAIGPTTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForMaskedLM,
|
||||
RobertaTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForMaskedLM,
|
||||
DistilBertTokenizer,
|
||||
CamembertConfig,
|
||||
CamembertForMaskedLM,
|
||||
CamembertTokenizer,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
||||
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||
'camembert': (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer)
|
||||
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||
"openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
||||
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||
"camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
def __init__(self, tokenizer, args, file_path='train', block_size=512):
|
||||
def __init__(self, tokenizer, args, file_path="train", block_size=512):
|
||||
assert os.path.isfile(file_path)
|
||||
directory, filename = os.path.split(file_path)
|
||||
cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_' + filename)
|
||||
cached_features_file = os.path.join(
|
||||
directory, args.model_name_or_path + "_cached_lm_" + str(block_size) + "_" + filename
|
||||
)
|
||||
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
with open(cached_features_file, 'rb') as handle:
|
||||
with open(cached_features_file, "rb") as handle:
|
||||
self.examples = pickle.load(handle)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", directory)
|
||||
@@ -83,14 +101,14 @@ class TextDataset(Dataset):
|
||||
|
||||
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
|
||||
|
||||
for i in range(0, len(tokenized_text)-block_size+1, block_size): # Truncate in block of block_size
|
||||
self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i+block_size]))
|
||||
for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size
|
||||
self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size]))
|
||||
# Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
|
||||
# If your dataset is small, first you should loook for a bigger one :-) and second you
|
||||
# can change this behavior by adding (model specific) padding.
|
||||
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
with open(cached_features_file, 'wb') as handle:
|
||||
with open(cached_features_file, "wb") as handle:
|
||||
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def __len__(self):
|
||||
@@ -101,7 +119,12 @@ class TextDataset(Dataset):
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
||||
dataset = TextDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
||||
dataset = TextDataset(
|
||||
tokenizer,
|
||||
args,
|
||||
file_path=args.eval_data_file if evaluate else args.train_data_file,
|
||||
block_size=args.block_size,
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
@@ -120,7 +143,7 @@ def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
|
||||
return
|
||||
|
||||
# Check if we should delete older checkpoint(s)
|
||||
glob_checkpoints = glob.glob(os.path.join(args.output_dir, '{}-*'.format(checkpoint_prefix)))
|
||||
glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
|
||||
if len(glob_checkpoints) <= args.save_total_limit:
|
||||
return
|
||||
|
||||
@@ -129,7 +152,7 @@ def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
|
||||
if use_mtime:
|
||||
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
|
||||
else:
|
||||
regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path)
|
||||
regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
|
||||
if regex_match and regex_match.groups():
|
||||
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
|
||||
|
||||
@@ -147,7 +170,9 @@ def mask_tokens(inputs, tokenizer, args):
|
||||
labels = inputs.clone()
|
||||
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
||||
probability_matrix = torch.full(labels.shape, args.mlm_probability)
|
||||
special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
|
||||
special_tokens_mask = [
|
||||
tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
||||
]
|
||||
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
|
||||
masked_indices = torch.bernoulli(probability_matrix).bool()
|
||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||
@@ -181,19 +206,26 @@ def train(args, train_dataset, model, tokenizer):
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
@@ -208,17 +240,21 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
@@ -228,7 +264,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
# Check if continuing training from a checkpoint
|
||||
if os.path.exists(args.model_name_or_path):
|
||||
# set global_step to gobal_step of last saved checkpoint from model path
|
||||
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
|
||||
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
|
||||
@@ -239,11 +275,13 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
|
||||
model_to_resize = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_resize = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training
|
||||
model_to_resize.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
model.zero_grad()
|
||||
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
train_iterator = trange(
|
||||
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||
)
|
||||
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
@@ -285,31 +323,35 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
checkpoint_prefix = 'checkpoint'
|
||||
checkpoint_prefix = "checkpoint"
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
_rotate_checkpoints(args, checkpoint_prefix)
|
||||
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
@@ -365,9 +407,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
perplexity = torch.exp(torch.tensor(eval_loss))
|
||||
|
||||
result = {
|
||||
"perplexity": perplexity
|
||||
}
|
||||
result = {"perplexity": perplexity}
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
@@ -383,107 +423,167 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
||||
help="The input training data file (a text file).")
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
parser.add_argument(
|
||||
"--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--eval_data_file", default=None, type=str,
|
||||
help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
|
||||
parser.add_argument(
|
||||
"--eval_data_file",
|
||||
default=None,
|
||||
type=str,
|
||||
help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
|
||||
)
|
||||
|
||||
parser.add_argument("--model_type", default="bert", type=str,
|
||||
help="The model architecture to be fine-tuned.")
|
||||
parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str,
|
||||
help="The model checkpoint for weights initialization.")
|
||||
parser.add_argument("--model_type", default="bert", type=str, help="The model architecture to be fine-tuned.")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default="bert-base-cased",
|
||||
type=str,
|
||||
help="The model checkpoint for weights initialization.",
|
||||
)
|
||||
|
||||
parser.add_argument("--mlm", action='store_true',
|
||||
help="Train with masked-language modeling loss instead of language modeling.")
|
||||
parser.add_argument("--mlm_probability", type=float, default=0.15,
|
||||
help="Ratio of tokens to mask for masked language modeling loss")
|
||||
parser.add_argument(
|
||||
"--mlm", action="store_true", help="Train with masked-language modeling loss instead of language modeling."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
|
||||
)
|
||||
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
|
||||
parser.add_argument("--block_size", default=-1, type=int,
|
||||
help="Optional input sequence length after tokenization."
|
||||
"The training dataset will be truncated in block of this size for training."
|
||||
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||
help="Run evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Optional pretrained config name or path if not the same as model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_size",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="Optional input sequence length after tokenization."
|
||||
"The training dataset will be truncated in block of this size for training."
|
||||
"Default to the model max input length for single sentence inputs (take into account special tokens).",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=1.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument('--save_total_limit', type=int, default=None,
|
||||
help='Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default')
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--save_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="For distributed training: local_rank")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
|
||||
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
||||
"flag (masked language modeling).")
|
||||
raise ValueError(
|
||||
"BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
||||
"flag (masked language modeling)."
|
||||
)
|
||||
if args.eval_data_file is None and args.do_eval:
|
||||
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
||||
"or remove the --do_eval argument.")
|
||||
raise ValueError(
|
||||
"Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
||||
"or remove the --do_eval argument."
|
||||
)
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
@@ -495,16 +595,24 @@ def main():
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
@@ -514,18 +622,26 @@ def main():
|
||||
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
||||
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
if args.block_size <= 0:
|
||||
args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model
|
||||
args.block_size = (
|
||||
tokenizer.max_len_single_sentence
|
||||
) # Our input block size will be the max possible for the model
|
||||
args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
|
||||
model = model_class.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model.to(args.device)
|
||||
|
||||
if args.local_rank == 0:
|
||||
@@ -546,7 +662,6 @@ def main():
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
@@ -556,35 +671,38 @@ def main():
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
@@ -26,8 +26,7 @@ import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
try:
|
||||
@@ -37,34 +36,38 @@ except:
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForMultipleChoice, BertTokenizer,
|
||||
XLNetConfig, XLNetForMultipleChoice,
|
||||
XLNetTokenizer, RobertaConfig,
|
||||
RobertaForMultipleChoice, RobertaTokenizer)
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
BertConfig,
|
||||
BertForMultipleChoice,
|
||||
BertTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetForMultipleChoice,
|
||||
XLNetTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForMultipleChoice,
|
||||
RobertaTokenizer,
|
||||
)
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
from utils_multiple_choice import (convert_examples_to_features, processors)
|
||||
from utils_multiple_choice import convert_examples_to_features, processors
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ())
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ()
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer),
|
||||
'roberta': (RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer)
|
||||
"bert": (BertConfig, BertForMultipleChoice, BertTokenizer),
|
||||
"xlnet": (XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def select_field(features, field):
|
||||
return [
|
||||
[
|
||||
choice[field]
|
||||
for choice in feature.choices_features
|
||||
]
|
||||
for feature in features
|
||||
]
|
||||
return [[choice[field] for choice in feature.choices_features] for feature in features]
|
||||
|
||||
|
||||
def simple_accuracy(preds, labels):
|
||||
@@ -95,13 +98,18 @@ def train(args, train_dataset, model, tokenizer):
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
@@ -115,17 +123,21 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
@@ -141,15 +153,19 @@ def train(args, train_dataset, model, tokenizer):
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'labels': batch[3]}
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"token_type_ids": batch[2]
|
||||
if args.model_type in ["bert", "xlnet"]
|
||||
else None, # XLM don't use segment_ids
|
||||
"labels": batch[3],
|
||||
}
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
@@ -171,10 +187,12 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
if results["eval_acc"] > best_dev_acc:
|
||||
best_dev_acc = results["eval_acc"]
|
||||
best_dev_loss = results["eval_loss"]
|
||||
@@ -182,22 +200,33 @@ def train(args, train_dataset, model, tokenizer):
|
||||
if args.do_test:
|
||||
results_test = evaluate(args, model, tokenizer, test=True)
|
||||
for key, value in results_test.items():
|
||||
tb_writer.add_scalar('test_{}'.format(key), value, global_step)
|
||||
logger.info("test acc: %s, loss: %s, global steps: %s", str(results_test['eval_acc']), str(results_test['eval_loss']), str(global_step))
|
||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
logger.info("Average loss: %s at global step: %s", str((tr_loss - logging_loss)/args.logging_steps), str(global_step))
|
||||
tb_writer.add_scalar("test_{}".format(key), value, global_step)
|
||||
logger.info(
|
||||
"test acc: %s, loss: %s, global steps: %s",
|
||||
str(results_test["eval_acc"]),
|
||||
str(results_test["eval_loss"]),
|
||||
str(global_step),
|
||||
)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logger.info(
|
||||
"Average loss: %s at global step: %s",
|
||||
str((tr_loss - logging_loss) / args.logging_steps),
|
||||
str(global_step),
|
||||
)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_vocabulary(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
@@ -246,10 +275,14 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'labels': batch[3]}
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"token_type_ids": batch[2]
|
||||
if args.model_type in ["bert", "xlnet"]
|
||||
else None, # XLM don't use segment_ids
|
||||
"labels": batch[3],
|
||||
}
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
@@ -257,10 +290,10 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs['labels'].detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
preds = np.argmax(preds, axis=1)
|
||||
@@ -273,8 +306,14 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test)))
|
||||
writer.write("model =%s\n" % str(args.model_name_or_path))
|
||||
writer.write("total batch size=%d\n" % (args.per_gpu_train_batch_size * args.gradient_accumulation_steps *
|
||||
(torch.distributed.get_world_size() if args.local_rank != -1 else 1)))
|
||||
writer.write(
|
||||
"total batch size=%d\n"
|
||||
% (
|
||||
args.per_gpu_train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1)
|
||||
)
|
||||
)
|
||||
writer.write("train num epochs=%d\n" % args.num_train_epochs)
|
||||
writer.write("fp16 =%s\n" % args.fp16)
|
||||
writer.write("max seq length =%d\n" % args.max_seq_length)
|
||||
@@ -291,17 +330,21 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
|
||||
processor = processors[task]()
|
||||
# Load data features from cache or dataset file
|
||||
if evaluate:
|
||||
cached_mode = 'dev'
|
||||
cached_mode = "dev"
|
||||
elif test:
|
||||
cached_mode = 'test'
|
||||
cached_mode = "test"
|
||||
else:
|
||||
cached_mode = 'train'
|
||||
cached_mode = "train"
|
||||
assert (evaluate == True and test == True) == False
|
||||
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
|
||||
cached_mode,
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task)))
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
cached_mode,
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
@@ -320,8 +363,8 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
|
||||
label_list,
|
||||
args.max_seq_length,
|
||||
tokenizer,
|
||||
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
||||
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0
|
||||
pad_on_left=bool(args.model_type in ["xlnet"]), # pad on the left for xlnet
|
||||
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
@@ -331,9 +374,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long)
|
||||
all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long)
|
||||
all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long)
|
||||
all_input_ids = torch.tensor(select_field(features, "input_ids"), dtype=torch.long)
|
||||
all_input_mask = torch.tensor(select_field(features, "input_mask"), dtype=torch.long)
|
||||
all_segment_ids = torch.tensor(select_field(features, "segment_ids"), dtype=torch.long)
|
||||
all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
@@ -344,91 +387,150 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--do_test", action='store_true', help='Whether to run test on the test set')
|
||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||
help="Run evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--do_test", action="store_true", help="Whether to run test on the test set")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="For distributed training: local_rank")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
@@ -440,16 +542,24 @@ def main():
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
@@ -468,17 +578,23 @@ def main():
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model = model_class.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
@@ -494,7 +610,6 @@ def main():
|
||||
global_step, tr_loss, best_steps = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
@@ -504,19 +619,20 @@ def main():
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
@@ -524,17 +640,19 @@ def main():
|
||||
args.output_dir = args.model_name_or_path
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
if args.do_test and args.local_rank in [-1, 0]:
|
||||
@@ -546,13 +664,13 @@ def main():
|
||||
# logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix, test=True)
|
||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
if best_steps:
|
||||
logger.info("best steps of eval acc is the following checkpoints: %s", best_steps)
|
||||
|
||||
@@ -43,9 +43,12 @@ from transformers import XLMRobertaConfig, XLMRobertaForTokenClassification, XLM
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig,
|
||||
CamembertConfig, XLMRobertaConfig)),
|
||||
())
|
||||
(
|
||||
tuple(conf.pretrained_config_archive_map.keys())
|
||||
for conf in (BertConfig, RobertaConfig, DistilBertConfig, CamembertConfig, XLMRobertaConfig)
|
||||
),
|
||||
(),
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
||||
@@ -82,18 +85,24 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
@@ -108,18 +117,21 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (
|
||||
torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
@@ -129,7 +141,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
||||
# Check if continuing training from a checkpoint
|
||||
if os.path.exists(args.model_name_or_path):
|
||||
# set global_step to gobal_step of last saved checkpoint from model path
|
||||
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
|
||||
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
|
||||
@@ -140,7 +152,9 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
||||
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
train_iterator = trange(
|
||||
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||
)
|
||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
@@ -153,11 +167,11 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
||||
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"labels": batch[3]}
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM and RoBERTa don"t use segment_ids
|
||||
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||
@@ -187,7 +201,9 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev")
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
@@ -200,15 +216,17 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
@@ -249,11 +267,11 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"labels": batch[3]}
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM and RoBERTa don"t use segment_ids
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
@@ -287,7 +305,7 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
|
||||
"loss": eval_loss,
|
||||
"precision": precision_score(out_label_list, preds_list),
|
||||
"recall": recall_score(out_label_list, preds_list),
|
||||
"f1": f1_score(out_label_list, preds_list)
|
||||
"f1": f1_score(out_label_list, preds_list),
|
||||
}
|
||||
|
||||
logger.info("***** Eval results %s *****", prefix)
|
||||
@@ -302,29 +320,36 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(args.data_dir, "cached_{}_{}_{}".format(mode,
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length)))
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}".format(
|
||||
mode, list(filter(None, args.model_name_or_path.split("/"))).pop(), str(args.max_seq_length)
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
examples = read_examples_from_file(args.data_dir, mode)
|
||||
features = convert_examples_to_features(examples, labels, args.max_seq_length, tokenizer,
|
||||
cls_token_at_end=bool(args.model_type in ["xlnet"]),
|
||||
# xlnet has a cls token at the end
|
||||
cls_token=tokenizer.cls_token,
|
||||
cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0,
|
||||
sep_token=tokenizer.sep_token,
|
||||
sep_token_extra=bool(args.model_type in ["roberta"]),
|
||||
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||
pad_on_left=bool(args.model_type in ["xlnet"]),
|
||||
# pad on the left for xlnet
|
||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
|
||||
pad_token_label_id=pad_token_label_id
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
examples,
|
||||
labels,
|
||||
args.max_seq_length,
|
||||
tokenizer,
|
||||
cls_token_at_end=bool(args.model_type in ["xlnet"]),
|
||||
# xlnet has a cls token at the end
|
||||
cls_token=tokenizer.cls_token,
|
||||
cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0,
|
||||
sep_token=tokenizer.sep_token,
|
||||
sep_token_extra=bool(args.model_type in ["roberta"]),
|
||||
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||
pad_on_left=bool(args.model_type in ["xlnet"]),
|
||||
# pad on the left for xlnet
|
||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
|
||||
pad_token_label_id=pad_token_label_id,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
@@ -346,95 +371,151 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.")
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--labels", default="", type=str,
|
||||
help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.")
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.")
|
||||
parser.add_argument("--do_train", action="store_true",
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true",
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--do_predict", action="store_true",
|
||||
help="Whether to run predictions on the test set.")
|
||||
parser.add_argument("--evaluate_during_training", action="store_true",
|
||||
help="Whether to run evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action="store_true",
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument(
|
||||
"--labels",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training",
|
||||
action="store_true",
|
||||
help="Whether to run evaluation during training at each logging step.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action="store_true",
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument("--overwrite_output_dir", action="store_true",
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument("--overwrite_cache", action="store_true",
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument("--fp16", action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument("--fp16_opt_level", type=str, default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(
|
||||
args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir))
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
@@ -451,11 +532,19 @@ def main():
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
@@ -472,16 +561,22 @@ def main():
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model = model_class.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
@@ -505,7 +600,9 @@ def main():
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
@@ -518,7 +615,9 @@ def main():
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True)))
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
@@ -565,4 +664,3 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -17,7 +17,11 @@
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
|
||||
from transformers.data.metrics.squad_metrics import compute_predictions_logits, compute_predictions_log_probs, squad_evaluate
|
||||
from transformers.data.metrics.squad_metrics import (
|
||||
compute_predictions_logits,
|
||||
compute_predictions_log_probs,
|
||||
squad_evaluate,
|
||||
)
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
@@ -27,8 +31,7 @@ import glob
|
||||
import timeit
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (
|
||||
DataLoader, RandomSampler, SequentialSampler, TensorDataset)
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
try:
|
||||
@@ -38,32 +41,47 @@ except:
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForQuestionAnswering, BertTokenizer,
|
||||
RobertaForQuestionAnswering, RobertaTokenizer, RobertaConfig,
|
||||
XLMConfig, XLMForQuestionAnswering,
|
||||
XLMTokenizer, XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer,
|
||||
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer,
|
||||
AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer,
|
||||
XLMConfig, XLMForQuestionAnswering, XLMTokenizer,
|
||||
)
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
BertConfig,
|
||||
BertForQuestionAnswering,
|
||||
BertTokenizer,
|
||||
RobertaForQuestionAnswering,
|
||||
RobertaTokenizer,
|
||||
RobertaConfig,
|
||||
XLMConfig,
|
||||
XLMForQuestionAnswering,
|
||||
XLMTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForQuestionAnswering,
|
||||
DistilBertTokenizer,
|
||||
AlbertConfig,
|
||||
AlbertForQuestionAnswering,
|
||||
AlbertTokenizer,
|
||||
XLMConfig,
|
||||
XLMForQuestionAnswering,
|
||||
XLMTokenizer,
|
||||
)
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
||||
for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)), ())
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)),
|
||||
(),
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
'roberta': (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||
'albert': (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
|
||||
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
|
||||
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||
"albert": (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
@@ -85,49 +103,44 @@ def train(args, train_dataset, model, tokenizer):
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(
|
||||
train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (
|
||||
len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(
|
||||
train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(
|
||||
nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(
|
||||
nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters,
|
||||
lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(
|
||||
os.path.join(args.model_name_or_path, 'optimizer.pt')))
|
||||
scheduler.load_state_dict(torch.load(
|
||||
os.path.join(args.model_name_or_path, 'scheduler.pt')))
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
|
||||
model, optimizer = amp.initialize(
|
||||
model, optimizer, opt_level=args.fp16_opt_level)
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
@@ -135,20 +148,22 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d",
|
||||
args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(" Gradient Accumulation steps = %d",
|
||||
args.gradient_accumulation_steps)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 1
|
||||
@@ -157,29 +172,25 @@ def train(args, train_dataset, model, tokenizer):
|
||||
# Check if continuing training from a checkpoint
|
||||
if os.path.exists(args.model_name_or_path):
|
||||
# set global_step to gobal_step of last saved checkpoint from model path
|
||||
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
|
||||
epochs_trained = global_step // (len(train_dataloader) //
|
||||
args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (
|
||||
len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
|
||||
logger.info(
|
||||
" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
logger.info(" Continuing training from global step %d", global_step)
|
||||
logger.info(" Will skip the first %d steps in the first epoch",
|
||||
steps_trained_in_current_epoch)
|
||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(epochs_trained, int(
|
||||
args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
train_iterator = trange(
|
||||
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||
)
|
||||
# Added here for reproductibility (even between python 2 and 3)
|
||||
set_seed(args)
|
||||
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration",
|
||||
disable=args.local_rank not in [-1, 0])
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
@@ -191,18 +202,17 @@ def train(args, train_dataset, model, tokenizer):
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
inputs = {
|
||||
'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
|
||||
'start_positions': batch[3],
|
||||
'end_positions': batch[4],
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
|
||||
"start_positions": batch[3],
|
||||
"end_positions": batch[4],
|
||||
}
|
||||
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[5],
|
||||
'p_mask': batch[6]})
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
||||
if args.version_2_with_negative:
|
||||
inputs.update({'is_impossible': batch[7]})
|
||||
inputs.update({"is_impossible": batch[7]})
|
||||
outputs = model(**inputs)
|
||||
# model outputs are always tuple in transformers (see doc)
|
||||
loss = outputs[0]
|
||||
@@ -221,11 +231,9 @@ def train(args, train_dataset, model, tokenizer):
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
amp.master_params(optimizer), args.max_grad_norm)
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(), args.max_grad_norm)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
@@ -238,36 +246,27 @@ def train(args, train_dataset, model, tokenizer):
|
||||
if args.local_rank == -1 and args.evaluate_during_training:
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar(
|
||||
'eval_{}'.format(key), value, global_step)
|
||||
tb_writer.add_scalar(
|
||||
'lr', scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar(
|
||||
'loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
# Save model checkpoint
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
output_dir = os.path.join(
|
||||
args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
# Take care of distributed/parallel training
|
||||
model_to_save = model.module if hasattr(
|
||||
model, 'module') else model
|
||||
model_to_save = model.module if hasattr(model, "module") else model
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
torch.save(args, os.path.join(
|
||||
output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
torch.save(optimizer.state_dict(), os.path.join(
|
||||
output_dir, 'optimizer.pt'))
|
||||
torch.save(scheduler.state_dict(), os.path.join(
|
||||
output_dir, 'scheduler.pt'))
|
||||
logger.info(
|
||||
"Saving optimizer and scheduler states to %s", output_dir)
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
@@ -283,8 +282,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
dataset, examples, features = load_and_cache_examples(
|
||||
args, tokenizer, evaluate=True, output_examples=True)
|
||||
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
|
||||
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
@@ -293,8 +291,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(dataset)
|
||||
eval_dataloader = DataLoader(
|
||||
dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu evaluate
|
||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
||||
@@ -314,15 +311,15 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {
|
||||
'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
|
||||
}
|
||||
example_indices = batch[3]
|
||||
|
||||
# XLNet and XLM use more arguments for their predictions
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[4], 'p_mask': batch[5]})
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
@@ -342,53 +339,68 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
cls_logits = output[4]
|
||||
|
||||
result = SquadResult(
|
||||
unique_id, start_logits, end_logits,
|
||||
unique_id,
|
||||
start_logits,
|
||||
end_logits,
|
||||
start_top_index=start_top_index,
|
||||
end_top_index=end_top_index,
|
||||
cls_logits=cls_logits
|
||||
cls_logits=cls_logits,
|
||||
)
|
||||
|
||||
else:
|
||||
start_logits, end_logits = output
|
||||
result = SquadResult(
|
||||
unique_id, start_logits, end_logits
|
||||
)
|
||||
result = SquadResult(unique_id, start_logits, end_logits)
|
||||
|
||||
all_results.append(result)
|
||||
|
||||
evalTime = timeit.default_timer() - start_time
|
||||
logger.info(" Evaluation done in total %f secs (%f sec per example)",
|
||||
evalTime, evalTime / len(dataset))
|
||||
logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))
|
||||
|
||||
# Compute predictions
|
||||
output_prediction_file = os.path.join(
|
||||
args.output_dir, "predictions_{}.json".format(prefix))
|
||||
output_nbest_file = os.path.join(
|
||||
args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
||||
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
|
||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
||||
|
||||
if args.version_2_with_negative:
|
||||
output_null_log_odds_file = os.path.join(
|
||||
args.output_dir, "null_odds_{}.json".format(prefix))
|
||||
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
|
||||
else:
|
||||
output_null_log_odds_file = None
|
||||
|
||||
# XLNet and XLM use a more complex post-processing procedure
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
start_n_top = model.config.start_n_top if hasattr(
|
||||
model, "config") else model.module.config.start_n_top
|
||||
end_n_top = model.config.end_n_top if hasattr(
|
||||
model, "config") else model.module.config.end_n_top
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
|
||||
end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top
|
||||
|
||||
predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file,
|
||||
start_n_top, end_n_top,
|
||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
||||
predictions = compute_predictions_log_probs(
|
||||
examples,
|
||||
features,
|
||||
all_results,
|
||||
args.n_best_size,
|
||||
args.max_answer_length,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
start_n_top,
|
||||
end_n_top,
|
||||
args.version_2_with_negative,
|
||||
tokenizer,
|
||||
args.verbose_logging,
|
||||
)
|
||||
else:
|
||||
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||
args.version_2_with_negative, args.null_score_diff_threshold, tokenizer)
|
||||
predictions = compute_predictions_logits(
|
||||
examples,
|
||||
features,
|
||||
all_results,
|
||||
args.n_best_size,
|
||||
args.max_answer_length,
|
||||
args.do_lower_case,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
args.verbose_logging,
|
||||
args.version_2_with_negative,
|
||||
args.null_score_diff_threshold,
|
||||
tokenizer,
|
||||
)
|
||||
|
||||
# Compute the F1 and exact scores.
|
||||
results = squad_evaluate(examples, predictions)
|
||||
@@ -402,16 +414,18 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
input_dir = args.data_dir if args.data_dir else "."
|
||||
cached_features_file = os.path.join(input_dir, 'cached_{}_{}_{}'.format(
|
||||
'dev' if evaluate else 'train',
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length))
|
||||
cached_features_file = os.path.join(
|
||||
input_dir,
|
||||
"cached_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
),
|
||||
)
|
||||
|
||||
# Init features and dataset from cache if it exists
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||
logger.info("Loading features from cached file %s",
|
||||
cached_features_file)
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features_and_dataset = torch.load(cached_features_file)
|
||||
features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
|
||||
else:
|
||||
@@ -421,16 +435,13 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
try:
|
||||
import tensorflow_datasets as tfds
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"If not data_dir is specified, tensorflow_datasets needs to be installed.")
|
||||
raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
|
||||
|
||||
if args.version_2_with_negative:
|
||||
logger.warn(
|
||||
"tensorflow_datasets does not handle version 2 of SQuAD.")
|
||||
logger.warn("tensorflow_datasets does not handle version 2 of SQuAD.")
|
||||
|
||||
tfds_examples = tfds.load("squad")
|
||||
examples = SquadV1Processor().get_examples_from_dataset(
|
||||
tfds_examples, evaluate=evaluate)
|
||||
examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
|
||||
else:
|
||||
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
||||
if evaluate:
|
||||
@@ -445,15 +456,13 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate,
|
||||
return_dataset='pt',
|
||||
return_dataset="pt",
|
||||
threads=args.threads,
|
||||
)
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s",
|
||||
cached_features_file)
|
||||
torch.save({"features": features, "dataset": dataset},
|
||||
cached_features_file)
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save({"features": features, "dataset": dataset}, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
@@ -468,140 +477,232 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.")
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.",
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str,
|
||||
help="The input data dir. Should contain the .json files for the task." +
|
||||
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
|
||||
parser.add_argument("--train_file", default=None, type=str,
|
||||
help="The input training file. If a data dir is specified, will look for the file there" +
|
||||
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
|
||||
parser.add_argument("--predict_file", default=None, type=str,
|
||||
help="The input evaluation file. If a data dir is specified, will look for the file there" +
|
||||
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The input data dir. Should contain the .json files for the task."
|
||||
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_file",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The input training file. If a data dir is specified, will look for the file there"
|
||||
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predict_file",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The input evaluation file. If a data dir is specified, will look for the file there"
|
||||
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
|
||||
parser.add_argument('--version_2_with_negative', action='store_true',
|
||||
help='If true, the SQuAD examples contain some that do not have an answer.')
|
||||
parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
|
||||
help="If null_score - best_non_null is greater than the threshold predict null.")
|
||||
parser.add_argument(
|
||||
"--version_2_with_negative",
|
||||
action="store_true",
|
||||
help="If true, the SQuAD examples contain some that do not have an answer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--null_score_diff_threshold",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="If null_score - best_non_null is greater than the threshold predict null.",
|
||||
)
|
||||
|
||||
parser.add_argument("--max_seq_length", default=384, type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded.")
|
||||
parser.add_argument("--doc_stride", default=128, type=int,
|
||||
help="When splitting up a long document into chunks, how much stride to take between chunks.")
|
||||
parser.add_argument("--max_query_length", default=64, type=int,
|
||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||
help="Rul evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=384,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--doc_stride",
|
||||
default=128,
|
||||
type=int,
|
||||
help="When splitting up a long document into chunks, how much stride to take between chunks.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_query_length",
|
||||
default=64,
|
||||
type=int,
|
||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--n_best_size", default=20, type=int,
|
||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
|
||||
parser.add_argument("--max_answer_length", default=30, type=int,
|
||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another.")
|
||||
parser.add_argument("--verbose_logging", action='store_true',
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument(
|
||||
"--n_best_size",
|
||||
default=20,
|
||||
type=int,
|
||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_answer_length",
|
||||
default=30,
|
||||
type=int,
|
||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose_logging",
|
||||
action="store_true",
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.",
|
||||
)
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Whether not to use CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="local_rank for distributed training on gpus")
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
|
||||
parser.add_argument('--threads', type=int, default=1, help='multiple threads for converting example to features')
|
||||
parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(
|
||||
address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt='%m/%d/%Y %H:%M:%S',
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
@@ -613,16 +714,21 @@ def main():
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model = model_class.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool(
|
||||
'.ckpt' in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.local_rank == 0:
|
||||
# Make sure only the first process in distributed training will download model & vocab
|
||||
@@ -638,18 +744,16 @@ def main():
|
||||
if args.fp16:
|
||||
try:
|
||||
import apex
|
||||
apex.amp.register_half_function(torch, 'einsum')
|
||||
|
||||
apex.amp.register_half_function(torch, "einsum")
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(
|
||||
args, tokenizer, evaluate=False, output_examples=False)
|
||||
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s",
|
||||
global_step, tr_loss)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
# Save the trained model and the tokenizer
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
@@ -661,18 +765,16 @@ def main():
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
# Take care of distributed/parallel training
|
||||
model_to_save = model.module if hasattr(model, 'module') else model
|
||||
model_to_save = model.module if hasattr(model, "module") else model
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(
|
||||
args.output_dir, force_download=True)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.output_dir, force_download=True)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||
@@ -682,7 +784,10 @@ def main():
|
||||
logger.info("Loading checkpoints saved during training for evaluation")
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
checkpoints = list(
|
||||
os.path.dirname(c)
|
||||
for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||
else:
|
||||
logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
|
||||
@@ -692,17 +797,14 @@ def main():
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split(
|
||||
'-')[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(
|
||||
checkpoint, force_download=True)
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint, force_download=True)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluate
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
|
||||
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v)
|
||||
for k, v in result.items())
|
||||
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
logger.info("Results: {}".format(results))
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
import os
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets
|
||||
from transformers import BertTokenizer, TFBertForSequenceClassification, BertConfig, glue_convert_examples_to_features, BertForSequenceClassification, glue_processors
|
||||
from transformers import (
|
||||
BertTokenizer,
|
||||
TFBertForSequenceClassification,
|
||||
BertConfig,
|
||||
glue_convert_examples_to_features,
|
||||
BertForSequenceClassification,
|
||||
glue_processors,
|
||||
)
|
||||
|
||||
# script parameters
|
||||
BATCH_SIZE = 32
|
||||
@@ -27,21 +34,21 @@ tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
|
||||
|
||||
# Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
|
||||
config = BertConfig.from_pretrained("bert-base-cased", num_labels=num_labels)
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
||||
model = TFBertForSequenceClassification.from_pretrained('bert-base-cased', config=config)
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
|
||||
model = TFBertForSequenceClassification.from_pretrained("bert-base-cased", config=config)
|
||||
|
||||
# Load dataset via TensorFlow Datasets
|
||||
data, info = tensorflow_datasets.load(f'glue/{TFDS_TASK}', with_info=True)
|
||||
train_examples = info.splits['train'].num_examples
|
||||
data, info = tensorflow_datasets.load(f"glue/{TFDS_TASK}", with_info=True)
|
||||
train_examples = info.splits["train"].num_examples
|
||||
|
||||
# MNLI expects either validation_matched or validation_mismatched
|
||||
valid_examples = info.splits['validation'].num_examples
|
||||
valid_examples = info.splits["validation"].num_examples
|
||||
|
||||
# Prepare dataset for GLUE as a tf.data.Dataset instance
|
||||
train_dataset = glue_convert_examples_to_features(data['train'], tokenizer, 128, TASK)
|
||||
train_dataset = glue_convert_examples_to_features(data["train"], tokenizer, 128, TASK)
|
||||
|
||||
# MNLI expects either validation_matched or validation_mismatched
|
||||
valid_dataset = glue_convert_examples_to_features(data['validation'], tokenizer, 128, TASK)
|
||||
valid_dataset = glue_convert_examples_to_features(data["validation"], tokenizer, 128, TASK)
|
||||
train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
|
||||
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
|
||||
|
||||
@@ -49,7 +56,7 @@ valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
|
||||
opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
|
||||
if USE_AMP:
|
||||
# loss scaling is currently required when using mixed precision
|
||||
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
|
||||
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")
|
||||
|
||||
|
||||
if num_labels == 1:
|
||||
@@ -57,37 +64,42 @@ if num_labels == 1:
|
||||
else:
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
model.compile(optimizer=opt, loss=loss, metrics=[metric])
|
||||
|
||||
# Train and evaluate using tf.keras.Model.fit()
|
||||
train_steps = train_examples//BATCH_SIZE
|
||||
valid_steps = valid_examples//EVAL_BATCH_SIZE
|
||||
train_steps = train_examples // BATCH_SIZE
|
||||
valid_steps = valid_examples // EVAL_BATCH_SIZE
|
||||
|
||||
history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_steps,
|
||||
validation_data=valid_dataset, validation_steps=valid_steps)
|
||||
history = model.fit(
|
||||
train_dataset,
|
||||
epochs=EPOCHS,
|
||||
steps_per_epoch=train_steps,
|
||||
validation_data=valid_dataset,
|
||||
validation_steps=valid_steps,
|
||||
)
|
||||
|
||||
# Save TF2 model
|
||||
os.makedirs('./save/', exist_ok=True)
|
||||
model.save_pretrained('./save/')
|
||||
os.makedirs("./save/", exist_ok=True)
|
||||
model.save_pretrained("./save/")
|
||||
|
||||
if TASK == "mrpc":
|
||||
# Load the TensorFlow model in PyTorch for inspection
|
||||
# This is to demo the interoperability between the two frameworks, you don't have to
|
||||
# do this in real life (you can run the inference on the TF model).
|
||||
pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)
|
||||
pytorch_model = BertForSequenceClassification.from_pretrained("./save/", from_tf=True)
|
||||
|
||||
# Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
|
||||
sentence_0 = 'This research was consistent with his findings.'
|
||||
sentence_1 = 'His findings were compatible with this research.'
|
||||
sentence_2 = 'His findings were not compatible with this research.'
|
||||
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')
|
||||
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')
|
||||
sentence_0 = "This research was consistent with his findings."
|
||||
sentence_1 = "His findings were compatible with this research."
|
||||
sentence_2 = "His findings were not compatible with this research."
|
||||
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors="pt")
|
||||
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors="pt")
|
||||
|
||||
del inputs_1["special_tokens_mask"]
|
||||
del inputs_2["special_tokens_mask"]
|
||||
|
||||
pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
|
||||
pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
|
||||
print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0')
|
||||
print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0')
|
||||
print("sentence_1 is", "a paraphrase" if pred_1 else "not a paraphrase", "of sentence_0")
|
||||
print("sentence_2 is", "a paraphrase" if pred_2 else "not a paraphrase", "of sentence_0")
|
||||
|
||||
@@ -21,189 +21,156 @@ from absl import app
|
||||
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
|
||||
())
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)), ()
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer)
|
||||
"distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
flags.DEFINE_string(
|
||||
"data_dir", None,
|
||||
"The input data dir. Should contain the .conll files (or other data files) "
|
||||
"for the task.")
|
||||
"data_dir", None, "The input data dir. Should contain the .conll files (or other data files) " "for the task."
|
||||
)
|
||||
|
||||
flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
|
||||
flags.DEFINE_string(
|
||||
"model_type", None,
|
||||
"Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
"model_name_or_path",
|
||||
None,
|
||||
"Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
)
|
||||
|
||||
flags.DEFINE_string("output_dir", None, "The output directory where the model checkpoints will be written.")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"model_name_or_path", None,
|
||||
"Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
"labels", "", "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."
|
||||
)
|
||||
|
||||
flags.DEFINE_string(
|
||||
"output_dir", None,
|
||||
"The output directory where the model checkpoints will be written.")
|
||||
flags.DEFINE_string("config_name", "", "Pretrained config name or path if not the same as model_name")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"labels", "",
|
||||
"Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.")
|
||||
flags.DEFINE_string("tokenizer_name", "", "Pretrained tokenizer name or path if not the same as model_name")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"config_name", "",
|
||||
"Pretrained config name or path if not the same as model_name")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"tokenizer_name", "",
|
||||
"Pretrained tokenizer name or path if not the same as model_name")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"cache_dir", "",
|
||||
"Where do you want to store the pre-trained models downloaded from s3")
|
||||
flags.DEFINE_string("cache_dir", "", "Where do you want to store the pre-trained models downloaded from s3")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"max_seq_length", 128,
|
||||
"max_seq_length",
|
||||
128,
|
||||
"The maximum total input sentence length after tokenization. "
|
||||
"Sequences longer than this will be truncated, sequences shorter "
|
||||
"will be padded.")
|
||||
"will be padded.",
|
||||
)
|
||||
|
||||
flags.DEFINE_string(
|
||||
"tpu", None,
|
||||
"tpu",
|
||||
None,
|
||||
"The Cloud TPU to use for training. This should be either the name "
|
||||
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
|
||||
"url.")
|
||||
"url.",
|
||||
)
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"num_tpu_cores", 8,
|
||||
"Total number of TPU cores to use.")
|
||||
flags.DEFINE_integer("num_tpu_cores", 8, "Total number of TPU cores to use.")
|
||||
|
||||
flags.DEFINE_boolean("do_train", False, "Whether to run training.")
|
||||
|
||||
flags.DEFINE_boolean("do_eval", False, "Whether to run eval on the dev set.")
|
||||
|
||||
flags.DEFINE_boolean("do_predict", False, "Whether to run predictions on the test set.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"do_train", False,
|
||||
"Whether to run training.")
|
||||
"evaluate_during_training", False, "Whether to run evaluation during training at each logging step."
|
||||
)
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"do_eval", False,
|
||||
"Whether to run eval on the dev set.")
|
||||
flags.DEFINE_boolean("do_lower_case", False, "Set this flag if you are using an uncased model.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"do_predict", False,
|
||||
"Whether to run predictions on the test set.")
|
||||
flags.DEFINE_integer("per_device_train_batch_size", 8, "Batch size per GPU/CPU/TPU for training.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"evaluate_during_training", False,
|
||||
"Whether to run evaluation during training at each logging step.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"do_lower_case", False,
|
||||
"Set this flag if you are using an uncased model.")
|
||||
flags.DEFINE_integer("per_device_eval_batch_size", 8, "Batch size per GPU/CPU/TPU for evaluation.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"per_device_train_batch_size", 8,
|
||||
"Batch size per GPU/CPU/TPU for training.")
|
||||
"gradient_accumulation_steps", 1, "Number of updates steps to accumulate before performing a backward/update pass."
|
||||
)
|
||||
|
||||
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
|
||||
|
||||
flags.DEFINE_float("weight_decay", 0.0, "Weight decay if we apply some.")
|
||||
|
||||
flags.DEFINE_float("adam_epsilon", 1e-8, "Epsilon for Adam optimizer.")
|
||||
|
||||
flags.DEFINE_float("max_grad_norm", 1.0, "Max gradient norm.")
|
||||
|
||||
flags.DEFINE_integer("num_train_epochs", 3, "Total number of training epochs to perform.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"per_device_eval_batch_size", 8,
|
||||
"Batch size per GPU/CPU/TPU for evaluation.")
|
||||
"max_steps", -1, "If > 0: set total number of training steps to perform. Override num_train_epochs."
|
||||
)
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"gradient_accumulation_steps", 1,
|
||||
"Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
flags.DEFINE_integer("warmup_steps", 0, "Linear warmup over warmup_steps.")
|
||||
|
||||
flags.DEFINE_float(
|
||||
"learning_rate", 5e-5,
|
||||
"The initial learning rate for Adam.")
|
||||
flags.DEFINE_integer("logging_steps", 50, "Log every X updates steps.")
|
||||
|
||||
flags.DEFINE_float(
|
||||
"weight_decay", 0.0,
|
||||
"Weight decay if we apply some.")
|
||||
|
||||
flags.DEFINE_float(
|
||||
"adam_epsilon", 1e-8,
|
||||
"Epsilon for Adam optimizer.")
|
||||
|
||||
flags.DEFINE_float(
|
||||
"max_grad_norm", 1.0,
|
||||
"Max gradient norm.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"num_train_epochs", 3,
|
||||
"Total number of training epochs to perform.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"max_steps", -1,
|
||||
"If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"warmup_steps", 0,
|
||||
"Linear warmup over warmup_steps.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"logging_steps", 50,
|
||||
"Log every X updates steps.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"save_steps", 50,
|
||||
"Save checkpoint every X updates steps.")
|
||||
flags.DEFINE_integer("save_steps", 50, "Save checkpoint every X updates steps.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"eval_all_checkpoints", False,
|
||||
"Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
"eval_all_checkpoints",
|
||||
False,
|
||||
"Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"no_cuda", False,
|
||||
"Avoid using CUDA when available")
|
||||
flags.DEFINE_boolean("no_cuda", False, "Avoid using CUDA when available")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"overwrite_output_dir", False,
|
||||
"Overwrite the content of the output directory")
|
||||
flags.DEFINE_boolean("overwrite_output_dir", False, "Overwrite the content of the output directory")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"overwrite_cache", False,
|
||||
"Overwrite the cached training and evaluation sets")
|
||||
flags.DEFINE_boolean("overwrite_cache", False, "Overwrite the cached training and evaluation sets")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"seed", 42,
|
||||
"random seed for initialization")
|
||||
flags.DEFINE_integer("seed", 42, "random seed for initialization")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"fp16", False,
|
||||
"Whether to use 16-bit (mixed) precision instead of 32-bit")
|
||||
flags.DEFINE_boolean("fp16", False, "Whether to use 16-bit (mixed) precision instead of 32-bit")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"gpus", "0",
|
||||
"gpus",
|
||||
"0",
|
||||
"Comma separated list of gpus devices. If only one, switch to single "
|
||||
"gpu strategy, if None takes all the gpus available.")
|
||||
"gpu strategy, if None takes all the gpus available.",
|
||||
)
|
||||
|
||||
|
||||
def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id):
|
||||
if args['max_steps'] > 0:
|
||||
num_train_steps = args['max_steps'] * args['gradient_accumulation_steps']
|
||||
args['num_train_epochs'] = 1
|
||||
def train(
|
||||
args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id
|
||||
):
|
||||
if args["max_steps"] > 0:
|
||||
num_train_steps = args["max_steps"] * args["gradient_accumulation_steps"]
|
||||
args["num_train_epochs"] = 1
|
||||
else:
|
||||
num_train_steps = math.ceil(num_train_examples / train_batch_size) // args['gradient_accumulation_steps'] * args['num_train_epochs']
|
||||
num_train_steps = (
|
||||
math.ceil(num_train_examples / train_batch_size)
|
||||
// args["gradient_accumulation_steps"]
|
||||
* args["num_train_epochs"]
|
||||
)
|
||||
|
||||
writer = tf.summary.create_file_writer("/tmp/mylogs")
|
||||
|
||||
with strategy.scope():
|
||||
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
|
||||
optimizer = create_optimizer(args['learning_rate'], num_train_steps, args['warmup_steps'])
|
||||
optimizer = create_optimizer(args["learning_rate"], num_train_steps, args["warmup_steps"])
|
||||
|
||||
if args['fp16']:
|
||||
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, 'dynamic')
|
||||
if args["fp16"]:
|
||||
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")
|
||||
|
||||
loss_metric = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
|
||||
loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
|
||||
gradient_accumulator = GradientAccumulator()
|
||||
|
||||
logging.info("***** Running training *****")
|
||||
logging.info(" Num examples = %d", num_train_examples)
|
||||
logging.info(" Num Epochs = %d", args['num_train_epochs'])
|
||||
logging.info(" Instantaneous batch size per device = %d", args['per_device_train_batch_size'])
|
||||
logging.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
train_batch_size * args['gradient_accumulation_steps'])
|
||||
logging.info(" Gradient Accumulation steps = %d", args['gradient_accumulation_steps'])
|
||||
logging.info(" Num Epochs = %d", args["num_train_epochs"])
|
||||
logging.info(" Instantaneous batch size per device = %d", args["per_device_train_batch_size"])
|
||||
logging.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
train_batch_size * args["gradient_accumulation_steps"],
|
||||
)
|
||||
logging.info(" Gradient Accumulation steps = %d", args["gradient_accumulation_steps"])
|
||||
logging.info(" Total training steps = %d", num_train_steps)
|
||||
|
||||
model.summary()
|
||||
@@ -214,26 +181,28 @@ def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, l
|
||||
|
||||
for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
|
||||
if gradient is not None:
|
||||
scaled_gradient = gradient / (args['n_device'] * args['gradient_accumulation_steps'])
|
||||
scaled_gradient = gradient / (args["n_device"] * args["gradient_accumulation_steps"])
|
||||
grads_and_vars.append((scaled_gradient, variable))
|
||||
else:
|
||||
grads_and_vars.append((gradient, variable))
|
||||
|
||||
optimizer.apply_gradients(grads_and_vars, args['max_grad_norm'])
|
||||
optimizer.apply_gradients(grads_and_vars, args["max_grad_norm"])
|
||||
gradient_accumulator.reset()
|
||||
|
||||
@tf.function
|
||||
def train_step(train_features, train_labels):
|
||||
def step_fn(train_features, train_labels):
|
||||
inputs = {'attention_mask': train_features['input_mask'], 'training': True}
|
||||
inputs = {"attention_mask": train_features["input_mask"], "training": True}
|
||||
|
||||
if args['model_type'] != "distilbert":
|
||||
inputs["token_type_ids"] = train_features['segment_ids'] if args['model_type'] in ["bert", "xlnet"] else None
|
||||
if args["model_type"] != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
train_features["segment_ids"] if args["model_type"] in ["bert", "xlnet"] else None
|
||||
)
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
logits = model(train_features['input_ids'], **inputs)[0]
|
||||
logits = model(train_features["input_ids"], **inputs)[0]
|
||||
logits = tf.reshape(logits, (-1, len(labels) + 1))
|
||||
active_loss = tf.reshape(train_features['input_mask'], (-1,))
|
||||
active_loss = tf.reshape(train_features["input_mask"], (-1,))
|
||||
active_logits = tf.boolean_mask(logits, active_loss)
|
||||
train_labels = tf.reshape(train_labels, (-1,))
|
||||
active_labels = tf.boolean_mask(train_labels, active_loss)
|
||||
@@ -251,29 +220,35 @@ def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, l
|
||||
return mean_loss
|
||||
|
||||
current_time = datetime.datetime.now()
|
||||
train_iterator = master_bar(range(args['num_train_epochs']))
|
||||
train_iterator = master_bar(range(args["num_train_epochs"]))
|
||||
global_step = 0
|
||||
logging_loss = 0.0
|
||||
|
||||
for epoch in train_iterator:
|
||||
epoch_iterator = progress_bar(train_dataset, total=num_train_steps, parent=train_iterator, display=args['n_device'] > 1)
|
||||
epoch_iterator = progress_bar(
|
||||
train_dataset, total=num_train_steps, parent=train_iterator, display=args["n_device"] > 1
|
||||
)
|
||||
step = 1
|
||||
|
||||
with strategy.scope():
|
||||
for train_features, train_labels in epoch_iterator:
|
||||
loss = train_step(train_features, train_labels)
|
||||
|
||||
if step % args['gradient_accumulation_steps'] == 0:
|
||||
if step % args["gradient_accumulation_steps"] == 0:
|
||||
strategy.experimental_run_v2(apply_gradients)
|
||||
|
||||
loss_metric(loss)
|
||||
|
||||
global_step += 1
|
||||
|
||||
if args['logging_steps'] > 0 and global_step % args['logging_steps'] == 0:
|
||||
if args["logging_steps"] > 0 and global_step % args["logging_steps"] == 0:
|
||||
# Log metrics
|
||||
if args['n_device'] == 1 and args['evaluate_during_training']: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
y_true, y_pred, eval_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev")
|
||||
if (
|
||||
args["n_device"] == 1 and args["evaluate_during_training"]
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
y_true, y_pred, eval_loss = evaluate(
|
||||
args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
|
||||
)
|
||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||
|
||||
logging.info("Eval at step " + str(global_step) + "\n" + report)
|
||||
@@ -294,16 +269,18 @@ def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, l
|
||||
|
||||
with writer.as_default():
|
||||
tf.summary.scalar("lr", learning_rate, global_step)
|
||||
tf.summary.scalar("loss", (loss_metric.result() - logging_loss) / args['logging_steps'], global_step)
|
||||
tf.summary.scalar(
|
||||
"loss", (loss_metric.result() - logging_loss) / args["logging_steps"], global_step
|
||||
)
|
||||
|
||||
logging_loss = loss_metric.result()
|
||||
|
||||
with writer.as_default():
|
||||
tf.summary.scalar("loss", loss_metric.result(), step=step)
|
||||
|
||||
if args['save_steps'] > 0 and global_step % args['save_steps'] == 0:
|
||||
if args["save_steps"] > 0 and global_step % args["save_steps"] == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args['output_dir'], "checkpoint-{}".format(global_step))
|
||||
output_dir = os.path.join(args["output_dir"], "checkpoint-{}".format(global_step))
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
@@ -311,10 +288,10 @@ def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, l
|
||||
model.save_pretrained(output_dir)
|
||||
logging.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
train_iterator.child.comment = f'loss : {loss_metric.result()}'
|
||||
train_iterator.child.comment = f"loss : {loss_metric.result()}"
|
||||
step += 1
|
||||
|
||||
train_iterator.write(f'loss epoch {epoch + 1}: {loss_metric.result()}')
|
||||
train_iterator.write(f"loss epoch {epoch + 1}: {loss_metric.result()}")
|
||||
|
||||
loss_metric.reset_states()
|
||||
|
||||
@@ -322,13 +299,15 @@ def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, l
|
||||
|
||||
|
||||
def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
|
||||
eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
|
||||
eval_dataset, size = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode)
|
||||
eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
|
||||
eval_dataset, size = load_and_cache_examples(
|
||||
args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode
|
||||
)
|
||||
eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
|
||||
preds = None
|
||||
num_eval_steps = math.ceil(size / eval_batch_size)
|
||||
master = master_bar(range(1))
|
||||
eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args['n_device'] > 1)
|
||||
eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args["n_device"] > 1)
|
||||
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
|
||||
loss = 0.0
|
||||
|
||||
@@ -337,15 +316,17 @@ def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode)
|
||||
logging.info(" Batch size = %d", eval_batch_size)
|
||||
|
||||
for eval_features, eval_labels in eval_iterator:
|
||||
inputs = {'attention_mask': eval_features['input_mask'], 'training': False}
|
||||
inputs = {"attention_mask": eval_features["input_mask"], "training": False}
|
||||
|
||||
if args['model_type'] != "distilbert":
|
||||
inputs["token_type_ids"] = eval_features['segment_ids'] if args['model_type'] in ["bert", "xlnet"] else None
|
||||
if args["model_type"] != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
eval_features["segment_ids"] if args["model_type"] in ["bert", "xlnet"] else None
|
||||
)
|
||||
|
||||
with strategy.scope():
|
||||
logits = model(eval_features['input_ids'], **inputs)[0]
|
||||
logits = model(eval_features["input_ids"], **inputs)[0]
|
||||
tmp_logits = tf.reshape(logits, (-1, len(labels) + 1))
|
||||
active_loss = tf.reshape(eval_features['input_mask'], (-1,))
|
||||
active_loss = tf.reshape(eval_features["input_mask"], (-1,))
|
||||
active_logits = tf.boolean_mask(tmp_logits, active_loss)
|
||||
tmp_eval_labels = tf.reshape(eval_labels, (-1,))
|
||||
active_labels = tf.boolean_mask(tmp_eval_labels, active_loss)
|
||||
@@ -384,11 +365,11 @@ def load_cache(cached_file, max_seq_length):
|
||||
def _decode_record(record):
|
||||
example = tf.io.parse_single_example(record, name_to_features)
|
||||
features = {}
|
||||
features['input_ids'] = example['input_ids']
|
||||
features['input_mask'] = example['input_mask']
|
||||
features['segment_ids'] = example['segment_ids']
|
||||
features["input_ids"] = example["input_ids"]
|
||||
features["input_mask"] = example["input_mask"]
|
||||
features["segment_ids"] = example["segment_ids"]
|
||||
|
||||
return features, example['label_ids']
|
||||
return features, example["label_ids"]
|
||||
|
||||
d = tf.data.TFRecordDataset(cached_file)
|
||||
d = d.map(_decode_record, num_parallel_calls=4)
|
||||
@@ -422,39 +403,46 @@ def save_cache(features, cached_features_file):
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
|
||||
drop_remainder = True if args['tpu'] or mode == 'train' else False
|
||||
drop_remainder = True if args["tpu"] or mode == "train" else False
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(args['data_dir'], "cached_{}_{}_{}.tf_record".format(mode,
|
||||
list(filter(None, args['model_name_or_path'].split("/"))).pop(),
|
||||
str(args['max_seq_length'])))
|
||||
if os.path.exists(cached_features_file) and not args['overwrite_cache']:
|
||||
cached_features_file = os.path.join(
|
||||
args["data_dir"],
|
||||
"cached_{}_{}_{}.tf_record".format(
|
||||
mode, list(filter(None, args["model_name_or_path"].split("/"))).pop(), str(args["max_seq_length"])
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args["overwrite_cache"]:
|
||||
logging.info("Loading features from cached file %s", cached_features_file)
|
||||
dataset, size = load_cache(cached_features_file, args['max_seq_length'])
|
||||
dataset, size = load_cache(cached_features_file, args["max_seq_length"])
|
||||
else:
|
||||
logging.info("Creating features from dataset file at %s", args['data_dir'])
|
||||
examples = read_examples_from_file(args['data_dir'], mode)
|
||||
features = convert_examples_to_features(examples, labels, args['max_seq_length'], tokenizer,
|
||||
cls_token_at_end=bool(args['model_type'] in ["xlnet"]),
|
||||
# xlnet has a cls token at the end
|
||||
cls_token=tokenizer.cls_token,
|
||||
cls_token_segment_id=2 if args['model_type'] in ["xlnet"] else 0,
|
||||
sep_token=tokenizer.sep_token,
|
||||
sep_token_extra=bool(args['model_type'] in ["roberta"]),
|
||||
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||
pad_on_left=bool(args['model_type'] in ["xlnet"]),
|
||||
# pad on the left for xlnet
|
||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||
pad_token_segment_id=4 if args['model_type'] in ["xlnet"] else 0,
|
||||
pad_token_label_id=pad_token_label_id
|
||||
)
|
||||
logging.info("Creating features from dataset file at %s", args["data_dir"])
|
||||
examples = read_examples_from_file(args["data_dir"], mode)
|
||||
features = convert_examples_to_features(
|
||||
examples,
|
||||
labels,
|
||||
args["max_seq_length"],
|
||||
tokenizer,
|
||||
cls_token_at_end=bool(args["model_type"] in ["xlnet"]),
|
||||
# xlnet has a cls token at the end
|
||||
cls_token=tokenizer.cls_token,
|
||||
cls_token_segment_id=2 if args["model_type"] in ["xlnet"] else 0,
|
||||
sep_token=tokenizer.sep_token,
|
||||
sep_token_extra=bool(args["model_type"] in ["roberta"]),
|
||||
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||
pad_on_left=bool(args["model_type"] in ["xlnet"]),
|
||||
# pad on the left for xlnet
|
||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||
pad_token_segment_id=4 if args["model_type"] in ["xlnet"] else 0,
|
||||
pad_token_label_id=pad_token_label_id,
|
||||
)
|
||||
logging.info("Saving features into cached file %s", cached_features_file)
|
||||
save_cache(features, cached_features_file)
|
||||
dataset, size = load_cache(cached_features_file, args['max_seq_length'])
|
||||
dataset, size = load_cache(cached_features_file, args["max_seq_length"])
|
||||
|
||||
if mode == 'train':
|
||||
if mode == "train":
|
||||
dataset = dataset.repeat()
|
||||
dataset = dataset.shuffle(buffer_size=8192, seed=args['seed'])
|
||||
dataset = dataset.shuffle(buffer_size=8192, seed=args["seed"])
|
||||
|
||||
dataset = dataset.batch(batch_size, drop_remainder)
|
||||
dataset = dataset.prefetch(buffer_size=batch_size)
|
||||
@@ -466,83 +454,117 @@ def main(_):
|
||||
logging.set_verbosity(logging.INFO)
|
||||
args = flags.FLAGS.flag_values_dict()
|
||||
|
||||
if os.path.exists(args['output_dir']) and os.listdir(
|
||||
args['output_dir']) and args['do_train'] and not args['overwrite_output_dir']:
|
||||
if (
|
||||
os.path.exists(args["output_dir"])
|
||||
and os.listdir(args["output_dir"])
|
||||
and args["do_train"]
|
||||
and not args["overwrite_output_dir"]
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args['output_dir']))
|
||||
args["output_dir"]
|
||||
)
|
||||
)
|
||||
|
||||
if args['fp16']:
|
||||
if args["fp16"]:
|
||||
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
|
||||
|
||||
if args['tpu']:
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args['tpu'])
|
||||
if args["tpu"]:
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args["tpu"])
|
||||
tf.config.experimental_connect_to_cluster(resolver)
|
||||
tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||
strategy = tf.distribute.experimental.TPUStrategy(resolver)
|
||||
args['n_device'] = args['num_tpu_cores']
|
||||
elif len(args['gpus'].split(',')) > 1:
|
||||
args['n_device'] = len([f"/gpu:{gpu}" for gpu in args['gpus'].split(',')])
|
||||
strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args['gpus'].split(',')])
|
||||
elif args['no_cuda']:
|
||||
args['n_device'] = 1
|
||||
args["n_device"] = args["num_tpu_cores"]
|
||||
elif len(args["gpus"].split(",")) > 1:
|
||||
args["n_device"] = len([f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
|
||||
strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
|
||||
elif args["no_cuda"]:
|
||||
args["n_device"] = 1
|
||||
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
|
||||
else:
|
||||
args['n_device'] = len(args['gpus'].split(','))
|
||||
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args['gpus'].split(',')[0])
|
||||
args["n_device"] = len(args["gpus"].split(","))
|
||||
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args["gpus"].split(",")[0])
|
||||
|
||||
logging.warning("n_device: %s, distributed training: %s, 16-bits training: %s",
|
||||
args['n_device'], bool(args['n_device'] > 1), args['fp16'])
|
||||
logging.warning(
|
||||
"n_device: %s, distributed training: %s, 16-bits training: %s",
|
||||
args["n_device"],
|
||||
bool(args["n_device"] > 1),
|
||||
args["fp16"],
|
||||
)
|
||||
|
||||
labels = get_labels(args['labels'])
|
||||
labels = get_labels(args["labels"])
|
||||
num_labels = len(labels) + 1
|
||||
pad_token_label_id = 0
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args['model_type']]
|
||||
config = config_class.from_pretrained(args['config_name'] if args['config_name'] else args['model_name_or_path'],
|
||||
num_labels=num_labels,
|
||||
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
|
||||
config = config_class.from_pretrained(
|
||||
args["config_name"] if args["config_name"] else args["model_name_or_path"],
|
||||
num_labels=num_labels,
|
||||
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
||||
)
|
||||
|
||||
logging.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
if args['do_train']:
|
||||
tokenizer = tokenizer_class.from_pretrained(args['tokenizer_name'] if args['tokenizer_name'] else args['model_name_or_path'],
|
||||
do_lower_case=args['do_lower_case'],
|
||||
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
|
||||
if args["do_train"]:
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
|
||||
do_lower_case=args["do_lower_case"],
|
||||
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
||||
)
|
||||
|
||||
with strategy.scope():
|
||||
model = model_class.from_pretrained(args['model_name_or_path'],
|
||||
from_pt=bool(".bin" in args['model_name_or_path']),
|
||||
config=config,
|
||||
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
|
||||
model = model_class.from_pretrained(
|
||||
args["model_name_or_path"],
|
||||
from_pt=bool(".bin" in args["model_name_or_path"]),
|
||||
config=config,
|
||||
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
||||
)
|
||||
model.layers[-1].activation = tf.keras.activations.softmax
|
||||
|
||||
train_batch_size = args['per_device_train_batch_size'] * args['n_device']
|
||||
train_dataset, num_train_examples = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train")
|
||||
train_batch_size = args["per_device_train_batch_size"] * args["n_device"]
|
||||
train_dataset, num_train_examples = load_and_cache_examples(
|
||||
args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train"
|
||||
)
|
||||
train_dataset = strategy.experimental_distribute_dataset(train_dataset)
|
||||
train(args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id)
|
||||
train(
|
||||
args,
|
||||
strategy,
|
||||
train_dataset,
|
||||
tokenizer,
|
||||
model,
|
||||
num_train_examples,
|
||||
labels,
|
||||
train_batch_size,
|
||||
pad_token_label_id,
|
||||
)
|
||||
|
||||
if not os.path.exists(args['output_dir']):
|
||||
os.makedirs(args['output_dir'])
|
||||
if not os.path.exists(args["output_dir"]):
|
||||
os.makedirs(args["output_dir"])
|
||||
|
||||
logging.info("Saving model to %s", args['output_dir'])
|
||||
logging.info("Saving model to %s", args["output_dir"])
|
||||
|
||||
model.save_pretrained(args['output_dir'])
|
||||
tokenizer.save_pretrained(args['output_dir'])
|
||||
model.save_pretrained(args["output_dir"])
|
||||
tokenizer.save_pretrained(args["output_dir"])
|
||||
|
||||
# Evaluation
|
||||
if args['do_eval']:
|
||||
tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
|
||||
if args["do_eval"]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||
checkpoints = []
|
||||
results = []
|
||||
|
||||
if args['eval_all_checkpoints']:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args['output_dir'] + "/**/" + TF2_WEIGHTS_NAME, recursive=True), key=lambda f: int(''.join(filter(str.isdigit, f)) or -1)))
|
||||
if args["eval_all_checkpoints"]:
|
||||
checkpoints = list(
|
||||
os.path.dirname(c)
|
||||
for c in sorted(
|
||||
glob.glob(args["output_dir"] + "/**/" + TF2_WEIGHTS_NAME, recursive=True),
|
||||
key=lambda f: int("".join(filter(str.isdigit, f)) or -1),
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
if len(checkpoints) == 0:
|
||||
checkpoints.append(args['output_dir'])
|
||||
checkpoints.append(args["output_dir"])
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
|
||||
@@ -550,13 +572,15 @@ def main(_):
|
||||
with strategy.scope():
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
|
||||
y_true, y_pred, eval_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev")
|
||||
y_true, y_pred, eval_loss = evaluate(
|
||||
args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
|
||||
)
|
||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||
|
||||
if global_step:
|
||||
results.append({global_step + "_report": report, global_step + "_loss": eval_loss})
|
||||
|
||||
output_eval_file = os.path.join(args['output_dir'], "eval_results.txt")
|
||||
output_eval_file = os.path.join(args["output_dir"], "eval_results.txt")
|
||||
|
||||
with tf.io.gfile.GFile(output_eval_file, "w") as writer:
|
||||
for res in results:
|
||||
@@ -572,14 +596,16 @@ def main(_):
|
||||
writer.write(report)
|
||||
writer.write("\n")
|
||||
|
||||
if args['do_predict']:
|
||||
tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
|
||||
model = model_class.from_pretrained(args['output_dir'])
|
||||
eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
|
||||
predict_dataset, _ = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test")
|
||||
if args["do_predict"]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||
model = model_class.from_pretrained(args["output_dir"])
|
||||
eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
|
||||
predict_dataset, _ = load_and_cache_examples(
|
||||
args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
|
||||
)
|
||||
y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
|
||||
output_test_results_file = os.path.join(args['output_dir'], "test_results.txt")
|
||||
output_test_predictions_file = os.path.join(args['output_dir'], "test_predictions.txt")
|
||||
output_test_results_file = os.path.join(args["output_dir"], "test_results.txt")
|
||||
output_test_predictions_file = os.path.join(args["output_dir"], "test_predictions.txt")
|
||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||
|
||||
with tf.io.gfile.GFile(output_test_results_file, "w") as writer:
|
||||
@@ -591,7 +617,7 @@ def main(_):
|
||||
writer.write("\n\nloss = " + str(pred_loss))
|
||||
|
||||
with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
|
||||
with tf.io.gfile.GFile(os.path.join(args['data_dir'], "test.txt"), "r") as f:
|
||||
with tf.io.gfile.GFile(os.path.join(args["data_dir"], "test.txt"), "r") as f:
|
||||
example_id = 0
|
||||
|
||||
for line in f:
|
||||
|
||||
@@ -26,8 +26,7 @@ import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
try:
|
||||
@@ -37,10 +36,18 @@ except:
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (WEIGHTS_NAME,
|
||||
BertConfig, BertForSequenceClassification, BertTokenizer,
|
||||
XLMConfig, XLMForSequenceClassification, XLMTokenizer,
|
||||
DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
BertConfig,
|
||||
BertForSequenceClassification,
|
||||
BertTokenizer,
|
||||
XLMConfig,
|
||||
XLMForSequenceClassification,
|
||||
XLMTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForSequenceClassification,
|
||||
DistilBertTokenizer,
|
||||
)
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
@@ -52,12 +59,14 @@ from transformers import glue_convert_examples_to_features as convert_examples_t
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, DistilBertConfig, XLMConfig)), ())
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, DistilBertConfig, XLMConfig)), ()
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
|
||||
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
"xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
@@ -85,19 +94,26 @@ def train(args, train_dataset, model, tokenizer):
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
@@ -112,17 +128,21 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
@@ -132,7 +152,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
# Check if continuing training from a checkpoint
|
||||
if os.path.exists(args.model_name_or_path):
|
||||
# set global_step to gobal_step of last saved checkpoint from model path
|
||||
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
|
||||
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
|
||||
@@ -143,7 +163,9 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
train_iterator = trange(
|
||||
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||
)
|
||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
@@ -155,16 +177,16 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'labels': batch[3]}
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert'] else None # XLM and DistilBERT don't use segment_ids
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert"] else None
|
||||
) # XLM and DistilBERT don't use segment_ids
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
@@ -188,28 +210,32 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
@@ -258,11 +284,11 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'labels': batch[3]}
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert'] else None # XLM and DistilBERT don't use segment_ids
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert"] else None
|
||||
) # XLM and DistilBERT don't use segment_ids
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
@@ -270,16 +296,16 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs['labels'].detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if args.output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
else:
|
||||
raise ValueError('No other `output_mode` for XNLI.')
|
||||
raise ValueError("No other `output_mode` for XNLI.")
|
||||
result = compute_metrics(eval_task, preds, out_label_ids)
|
||||
results.update(result)
|
||||
|
||||
@@ -300,27 +326,34 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
processor = processors[task](language=args.language, train_language=args.train_language)
|
||||
output_mode = output_modes[task]
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}_{}'.format(
|
||||
'test' if evaluate else 'train',
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task),
|
||||
str(args.train_language if (not evaluate and args.train_language is not None) else args.language)))
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}_{}_{}".format(
|
||||
"test" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task),
|
||||
str(args.train_language if (not evaluate and args.train_language is not None) else args.language),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
examples = processor.get_test_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
features = convert_examples_to_features(examples,
|
||||
tokenizer,
|
||||
label_list=label_list,
|
||||
max_length=args.max_seq_length,
|
||||
output_mode=output_mode,
|
||||
pad_on_left=False,
|
||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||
pad_token_segment_id=0,
|
||||
examples = (
|
||||
processor.get_test_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
examples,
|
||||
tokenizer,
|
||||
label_list=label_list,
|
||||
max_length=args.max_seq_length,
|
||||
output_mode=output_mode,
|
||||
pad_on_left=False,
|
||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||
pad_token_segment_id=0,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
@@ -336,7 +369,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
if output_mode == "classification":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
else:
|
||||
raise ValueError('No other `output_mode` for XNLI.')
|
||||
raise ValueError("No other `output_mode` for XNLI.")
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
||||
return dataset
|
||||
@@ -346,92 +379,152 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--language", default=None, type=str, required=True,
|
||||
help="Evaluation language. Also train language if `train_language` is set to None.")
|
||||
parser.add_argument("--train_language", default=None, type=str,
|
||||
help="Train language if is different of the evaluation language.")
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Evaluation language. Also train language if `train_language` is set to None.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_language", default=None, type=str, help="Train language if is different of the evaluation language."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the test set.")
|
||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||
help="Rul evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the test set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="For distributed training: local_rank")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
@@ -443,22 +536,30 @@ def main():
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Prepare XNLI task
|
||||
args.task_name = 'xnli'
|
||||
args.task_name = "xnli"
|
||||
if args.task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (args.task_name))
|
||||
processor = processors[args.task_name](language=args.language, train_language=args.train_language)
|
||||
@@ -472,17 +573,23 @@ def main():
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model = model_class.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
@@ -491,14 +598,12 @@ def main():
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
@@ -508,36 +613,39 @@ def main():
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
@@ -34,12 +34,30 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SAMPLE_TEXT = 'Hello world! cécé herlolip'
|
||||
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
||||
|
||||
|
||||
BertAbsConfig = namedtuple(
|
||||
"BertAbsConfig",
|
||||
["temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
|
||||
[
|
||||
"temp_dir",
|
||||
"large",
|
||||
"use_bert_emb",
|
||||
"finetune_bert",
|
||||
"encoder",
|
||||
"share_emb",
|
||||
"max_pos",
|
||||
"enc_layers",
|
||||
"enc_hidden_size",
|
||||
"enc_heads",
|
||||
"enc_ff_size",
|
||||
"enc_dropout",
|
||||
"dec_layers",
|
||||
"dec_hidden_size",
|
||||
"dec_heads",
|
||||
"dec_ff_size",
|
||||
"dec_dropout",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -119,7 +137,9 @@ def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
||||
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
|
||||
output_original_generator = original.generator(output_original_model)
|
||||
|
||||
output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0]
|
||||
output_converted_model = new_model(
|
||||
encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask
|
||||
)[0]
|
||||
output_converted_generator = new_model.generator(output_converted_model)
|
||||
|
||||
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
|
||||
@@ -136,28 +156,21 @@ def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
||||
# The model has been saved with torch.save(model) and this is bound to the exact
|
||||
# directory structure. We save the state_dict instead.
|
||||
logging.info("saving the model's state dictionary")
|
||||
torch.save(new_model.state_dict(), "bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin")
|
||||
torch.save(
|
||||
new_model.state_dict(), "bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--bertabs_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path the official PyTorch dump.",
|
||||
"--bertabs_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.",
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_bertabs_checkpoints(
|
||||
args.bertabs_checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.bertabs_checkpoint_path, args.pytorch_dump_folder_path,
|
||||
)
|
||||
|
||||
@@ -56,40 +56,22 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
|
||||
if load_bert_pretrained_extractive:
|
||||
self.bert.model.load_state_dict(
|
||||
dict(
|
||||
[
|
||||
(n[11:], p)
|
||||
for n, p in bert_extractive_checkpoint.items()
|
||||
if n.startswith("bert.model")
|
||||
]
|
||||
),
|
||||
dict([(n[11:], p) for n, p in bert_extractive_checkpoint.items() if n.startswith("bert.model")]),
|
||||
strict=True,
|
||||
)
|
||||
|
||||
self.vocab_size = self.bert.model.config.vocab_size
|
||||
|
||||
if args.max_pos > 512:
|
||||
my_pos_embeddings = nn.Embedding(
|
||||
args.max_pos, self.bert.model.config.hidden_size
|
||||
)
|
||||
my_pos_embeddings.weight.data[
|
||||
:512
|
||||
] = self.bert.model.embeddings.position_embeddings.weight.data
|
||||
my_pos_embeddings.weight.data[
|
||||
512:
|
||||
] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
|
||||
my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
|
||||
my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
|
||||
my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
|
||||
None, :
|
||||
].repeat(
|
||||
args.max_pos - 512, 1
|
||||
)
|
||||
].repeat(args.max_pos - 512, 1)
|
||||
self.bert.model.embeddings.position_embeddings = my_pos_embeddings
|
||||
tgt_embeddings = nn.Embedding(
|
||||
self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
|
||||
)
|
||||
tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
|
||||
|
||||
tgt_embeddings.weight = copy.deepcopy(
|
||||
self.bert.model.embeddings.word_embeddings.weight
|
||||
)
|
||||
tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
|
||||
|
||||
self.decoder = TransformerDecoder(
|
||||
self.args.dec_layers,
|
||||
@@ -102,9 +84,7 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
)
|
||||
|
||||
gen_func = nn.LogSoftmax(dim=-1)
|
||||
self.generator = nn.Sequential(
|
||||
nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func
|
||||
)
|
||||
self.generator = nn.Sequential(nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func)
|
||||
self.generator[0].weight = self.decoder.embeddings.weight
|
||||
|
||||
load_from_checkpoints = False if checkpoint is None else True
|
||||
@@ -127,25 +107,14 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
p.data.zero_()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_input_ids,
|
||||
decoder_input_ids,
|
||||
token_type_ids,
|
||||
encoder_attention_mask,
|
||||
decoder_attention_mask,
|
||||
self, encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask,
|
||||
):
|
||||
encoder_output = self.bert(
|
||||
input_ids=encoder_input_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=encoder_attention_mask,
|
||||
input_ids=encoder_input_ids, token_type_ids=token_type_ids, attention_mask=encoder_attention_mask,
|
||||
)
|
||||
encoder_hidden_states = encoder_output[0]
|
||||
dec_state = self.decoder.init_decoder_state(
|
||||
encoder_input_ids, encoder_hidden_states
|
||||
)
|
||||
decoder_outputs, _ = self.decoder(
|
||||
decoder_input_ids[:, :-1], encoder_hidden_states, dec_state
|
||||
)
|
||||
dec_state = self.decoder.init_decoder_state(encoder_input_ids, encoder_hidden_states)
|
||||
decoder_outputs, _ = self.decoder(decoder_input_ids[:, :-1], encoder_hidden_states, dec_state)
|
||||
return decoder_outputs
|
||||
|
||||
|
||||
@@ -162,10 +131,7 @@ class Bert(nn.Module):
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
encoder_outputs, _ = self.model(
|
||||
input_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs
|
||||
input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, **kwargs
|
||||
)
|
||||
return encoder_outputs
|
||||
|
||||
@@ -196,10 +162,7 @@ class TransformerDecoder(nn.Module):
|
||||
|
||||
# Build TransformerDecoder.
|
||||
self.transformer_layers = nn.ModuleList(
|
||||
[
|
||||
TransformerDecoderLayer(d_model, heads, d_ff, dropout)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
[TransformerDecoderLayer(d_model, heads, d_ff, dropout) for _ in range(num_layers)]
|
||||
)
|
||||
|
||||
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
||||
@@ -236,20 +199,14 @@ class TransformerDecoder(nn.Module):
|
||||
# Decoder padding mask
|
||||
tgt_words = tgt
|
||||
tgt_batch, tgt_len = tgt_words.size()
|
||||
tgt_pad_mask = (
|
||||
tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)
|
||||
)
|
||||
tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)
|
||||
|
||||
# Encoder padding mask
|
||||
if memory_mask is not None:
|
||||
src_len = memory_mask.size(-1)
|
||||
src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len)
|
||||
else:
|
||||
src_pad_mask = (
|
||||
src_words.data.eq(padding_idx)
|
||||
.unsqueeze(1)
|
||||
.expand(src_batch, tgt_len, src_len)
|
||||
)
|
||||
src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1).expand(src_batch, tgt_len, src_len)
|
||||
|
||||
# Pass through the embeddings
|
||||
emb = self.embeddings(input_ids)
|
||||
@@ -271,9 +228,7 @@ class TransformerDecoder(nn.Module):
|
||||
src_pad_mask,
|
||||
tgt_pad_mask,
|
||||
previous_input=prev_layer_input,
|
||||
layer_cache=state.cache["layer_{}".format(i)]
|
||||
if state.cache is not None
|
||||
else None,
|
||||
layer_cache=state.cache["layer_{}".format(i)] if state.cache is not None else None,
|
||||
step=step,
|
||||
)
|
||||
if state.cache is None:
|
||||
@@ -303,9 +258,7 @@ class PositionalEncoding(nn.Module):
|
||||
def __init__(self, dropout, dim, max_len=5000):
|
||||
pe = torch.zeros(max_len, dim)
|
||||
position = torch.arange(0, max_len).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
|
||||
)
|
||||
div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)))
|
||||
pe[:, 0::2] = torch.sin(position.float() * div_term)
|
||||
pe[:, 1::2] = torch.cos(position.float() * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
@@ -356,14 +309,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs,
|
||||
memory_bank,
|
||||
src_pad_mask,
|
||||
tgt_pad_mask,
|
||||
previous_input=None,
|
||||
layer_cache=None,
|
||||
step=None,
|
||||
self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, previous_input=None, layer_cache=None, step=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -380,34 +326,20 @@ class TransformerDecoderLayer(nn.Module):
|
||||
* all_input `[batch_size x current_step x model_dim]`
|
||||
|
||||
"""
|
||||
dec_mask = torch.gt(
|
||||
tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0
|
||||
)
|
||||
dec_mask = torch.gt(tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0)
|
||||
input_norm = self.layer_norm_1(inputs)
|
||||
all_input = input_norm
|
||||
if previous_input is not None:
|
||||
all_input = torch.cat((previous_input, input_norm), dim=1)
|
||||
dec_mask = None
|
||||
|
||||
query = self.self_attn(
|
||||
all_input,
|
||||
all_input,
|
||||
input_norm,
|
||||
mask=dec_mask,
|
||||
layer_cache=layer_cache,
|
||||
type="self",
|
||||
)
|
||||
query = self.self_attn(all_input, all_input, input_norm, mask=dec_mask, layer_cache=layer_cache, type="self",)
|
||||
|
||||
query = self.drop(query) + inputs
|
||||
|
||||
query_norm = self.layer_norm_2(query)
|
||||
mid = self.context_attn(
|
||||
memory_bank,
|
||||
memory_bank,
|
||||
query_norm,
|
||||
mask=src_pad_mask,
|
||||
layer_cache=layer_cache,
|
||||
type="context",
|
||||
memory_bank, memory_bank, query_norm, mask=src_pad_mask, layer_cache=layer_cache, type="context",
|
||||
)
|
||||
output = self.feed_forward(self.drop(mid) + query)
|
||||
|
||||
@@ -492,14 +424,7 @@ class MultiHeadedAttention(nn.Module):
|
||||
self.final_linear = nn.Linear(model_dim, model_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
key,
|
||||
value,
|
||||
query,
|
||||
mask=None,
|
||||
layer_cache=None,
|
||||
type=None,
|
||||
predefined_graph_1=None,
|
||||
self, key, value, query, mask=None, layer_cache=None, type=None, predefined_graph_1=None,
|
||||
):
|
||||
"""
|
||||
Compute the context vector and the attention vectors.
|
||||
@@ -531,11 +456,7 @@ class MultiHeadedAttention(nn.Module):
|
||||
|
||||
def unshape(x):
|
||||
""" compute context """
|
||||
return (
|
||||
x.transpose(1, 2)
|
||||
.contiguous()
|
||||
.view(batch_size, -1, head_count * dim_per_head)
|
||||
)
|
||||
return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)
|
||||
|
||||
# 1) Project key, value, and query.
|
||||
if layer_cache is not None:
|
||||
@@ -554,9 +475,7 @@ class MultiHeadedAttention(nn.Module):
|
||||
if layer_cache["self_keys"] is not None:
|
||||
key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2)
|
||||
if layer_cache["self_values"] is not None:
|
||||
value = torch.cat(
|
||||
(layer_cache["self_values"].to(device), value), dim=2
|
||||
)
|
||||
value = torch.cat((layer_cache["self_values"].to(device), value), dim=2)
|
||||
layer_cache["self_keys"] = key
|
||||
layer_cache["self_values"] = value
|
||||
elif type == "context":
|
||||
@@ -637,13 +556,9 @@ class DecoderState(object):
|
||||
sizes = e.size()
|
||||
br = sizes[1]
|
||||
if len(sizes) == 3:
|
||||
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[
|
||||
:, :, idx
|
||||
]
|
||||
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[:, :, idx]
|
||||
else:
|
||||
sent_states = e.view(
|
||||
sizes[0], beam_size, br // beam_size, sizes[2], sizes[3]
|
||||
)[:, :, idx]
|
||||
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2], sizes[3])[:, :, idx]
|
||||
|
||||
sent_states.data.copy_(sent_states.data.index_select(1, positions))
|
||||
|
||||
@@ -716,11 +631,7 @@ class TransformerDecoderState(DecoderState):
|
||||
|
||||
|
||||
def gelu(x):
|
||||
return (
|
||||
0.5
|
||||
* x
|
||||
* (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
)
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
@@ -758,9 +669,7 @@ class PositionwiseFeedForward(nn.Module):
|
||||
def build_predictor(args, tokenizer, symbols, model, logger=None):
|
||||
# we should be able to refactor the global scorer a lot
|
||||
scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu")
|
||||
translator = Translator(
|
||||
args, model, tokenizer, symbols, global_scorer=scorer, logger=logger
|
||||
)
|
||||
translator = Translator(args, model, tokenizer, symbols, global_scorer=scorer, logger=logger)
|
||||
return translator
|
||||
|
||||
|
||||
@@ -891,9 +800,7 @@ class Translator(object):
|
||||
Shouldn't need the original dataset.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return self._fast_translate_batch(
|
||||
batch, self.max_length, min_length=self.min_length
|
||||
)
|
||||
return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length)
|
||||
|
||||
# Where the beam search lives
|
||||
# I have no idea why it is being called from the method above
|
||||
@@ -912,26 +819,18 @@ class Translator(object):
|
||||
mask_src = batch.mask_src
|
||||
|
||||
src_features = self.model.bert(src, segs, mask_src)
|
||||
dec_states = self.model.decoder.init_decoder_state(
|
||||
src, src_features, with_cache=True
|
||||
)
|
||||
dec_states = self.model.decoder.init_decoder_state(src, src_features, with_cache=True)
|
||||
device = src_features.device
|
||||
|
||||
# Tile states and memory beam_size times.
|
||||
dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim))
|
||||
src_features = tile(src_features, beam_size, dim=0)
|
||||
batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
|
||||
beam_offset = torch.arange(
|
||||
0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device
|
||||
)
|
||||
alive_seq = torch.full(
|
||||
[batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device
|
||||
)
|
||||
beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device)
|
||||
alive_seq = torch.full([batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device)
|
||||
|
||||
# Give full probability to the first beam on the first step.
|
||||
topk_log_probs = torch.tensor(
|
||||
[0.0] + [float("-inf")] * (beam_size - 1), device=device
|
||||
).repeat(batch_size)
|
||||
topk_log_probs = torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=device).repeat(batch_size)
|
||||
|
||||
# Structure that holds finished hypotheses.
|
||||
hypotheses = [[] for _ in range(batch_size)] # noqa: F812
|
||||
@@ -948,9 +847,7 @@ class Translator(object):
|
||||
# Decoder forward.
|
||||
decoder_input = decoder_input.transpose(0, 1)
|
||||
|
||||
dec_out, dec_states = self.model.decoder(
|
||||
decoder_input, src_features, dec_states, step=step
|
||||
)
|
||||
dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step)
|
||||
|
||||
# Generator forward.
|
||||
log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0))
|
||||
@@ -978,10 +875,7 @@ class Translator(object):
|
||||
words = " ".join(words).replace(" ##", "").split()
|
||||
if len(words) <= 3:
|
||||
continue
|
||||
trigrams = [
|
||||
(words[i - 1], words[i], words[i + 1])
|
||||
for i in range(1, len(words) - 1)
|
||||
]
|
||||
trigrams = [(words[i - 1], words[i], words[i + 1]) for i in range(1, len(words) - 1)]
|
||||
trigram = tuple(trigrams[-1])
|
||||
if trigram in trigrams[:-1]:
|
||||
fail = True
|
||||
@@ -999,15 +893,11 @@ class Translator(object):
|
||||
topk_ids = topk_ids.fmod(vocab_size)
|
||||
|
||||
# Map beam_index to batch_index in the flat representation.
|
||||
batch_index = topk_beam_index + beam_offset[
|
||||
: topk_beam_index.size(0)
|
||||
].unsqueeze(1)
|
||||
batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
|
||||
select_indices = batch_index.view(-1)
|
||||
|
||||
# Append last prediction.
|
||||
alive_seq = torch.cat(
|
||||
[alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1
|
||||
)
|
||||
alive_seq = torch.cat([alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1)
|
||||
|
||||
is_finished = topk_ids.eq(self.end_token)
|
||||
if step + 1 == max_length:
|
||||
@@ -1040,15 +930,11 @@ class Translator(object):
|
||||
topk_log_probs = topk_log_probs.index_select(0, non_finished)
|
||||
batch_index = batch_index.index_select(0, non_finished)
|
||||
batch_offset = batch_offset.index_select(0, non_finished)
|
||||
alive_seq = predictions.index_select(0, non_finished).view(
|
||||
-1, alive_seq.size(-1)
|
||||
)
|
||||
alive_seq = predictions.index_select(0, non_finished).view(-1, alive_seq.size(-1))
|
||||
# Reorder states.
|
||||
select_indices = batch_index.view(-1)
|
||||
src_features = src_features.index_select(0, select_indices)
|
||||
dec_states.map_batch_fn(
|
||||
lambda state, dim: state.index_select(dim, select_indices)
|
||||
)
|
||||
dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))
|
||||
|
||||
return results
|
||||
|
||||
@@ -1089,14 +975,7 @@ def tile(x, count, dim=0):
|
||||
out_size = list(x.size())
|
||||
out_size[0] *= count
|
||||
batch = x.size(0)
|
||||
x = (
|
||||
x.view(batch, -1)
|
||||
.transpose(0, 1)
|
||||
.repeat(count, 1)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
.view(*out_size)
|
||||
)
|
||||
x = x.view(batch, -1).transpose(0, 1).repeat(count, 1).transpose(0, 1).contiguous().view(*out_size)
|
||||
if dim != 0:
|
||||
x = x.permute(perm).contiguous()
|
||||
return x
|
||||
@@ -1107,6 +986,7 @@ def tile(x, count, dim=0):
|
||||
# a finetuning script.
|
||||
#
|
||||
|
||||
|
||||
class BertSumOptimizer(object):
|
||||
""" Specific optimizer for BertSum.
|
||||
|
||||
@@ -1126,16 +1006,10 @@ class BertSumOptimizer(object):
|
||||
|
||||
self.optimizers = {
|
||||
"encoder": torch.optim.Adam(
|
||||
model.encoder.parameters(),
|
||||
lr=lr["encoder"],
|
||||
betas=(beta_1, beta_2),
|
||||
eps=eps,
|
||||
model.encoder.parameters(), lr=lr["encoder"], betas=(beta_1, beta_2), eps=eps,
|
||||
),
|
||||
"decoder": torch.optim.Adam(
|
||||
model.decoder.parameters(),
|
||||
lr=lr["decoder"],
|
||||
betas=(beta_1, beta_2),
|
||||
eps=eps,
|
||||
model.decoder.parameters(), lr=lr["decoder"], betas=(beta_1, beta_2), eps=eps,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -1143,9 +1017,7 @@ class BertSumOptimizer(object):
|
||||
self.current_learning_rates = {}
|
||||
|
||||
def _update_rate(self, stack):
|
||||
return self.lr[stack] * min(
|
||||
self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5)
|
||||
)
|
||||
return self.lr[stack] * min(self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5))
|
||||
|
||||
def zero_grad(self):
|
||||
self.optimizer_decoder.zero_grad()
|
||||
|
||||
@@ -25,9 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
|
||||
|
||||
Batch = namedtuple(
|
||||
"Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"]
|
||||
)
|
||||
Batch = namedtuple("Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"])
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
@@ -48,13 +46,14 @@ def evaluate(args):
|
||||
|
||||
import rouge
|
||||
import nltk
|
||||
nltk.download('punkt')
|
||||
|
||||
nltk.download("punkt")
|
||||
rouge_evaluator = rouge.Rouge(
|
||||
metrics=['rouge-n', 'rouge-l'],
|
||||
metrics=["rouge-n", "rouge-l"],
|
||||
max_n=2,
|
||||
limit_length=True,
|
||||
length_limit=args.beam_size,
|
||||
length_limit_type='words',
|
||||
length_limit_type="words",
|
||||
apply_avg=True,
|
||||
apply_best=False,
|
||||
alpha=0.5, # Default F1_score
|
||||
@@ -161,15 +160,15 @@ Recall >> {:.3f}
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}""".format(
|
||||
scores['rouge-1']['f'],
|
||||
scores['rouge-1']['p'],
|
||||
scores['rouge-1']['r'],
|
||||
scores['rouge-2']['f'],
|
||||
scores['rouge-2']['p'],
|
||||
scores['rouge-2']['r'],
|
||||
scores['rouge-l']['f'],
|
||||
scores['rouge-l']['p'],
|
||||
scores['rouge-l']['r'],
|
||||
scores["rouge-1"]["f"],
|
||||
scores["rouge-1"]["p"],
|
||||
scores["rouge-1"]["r"],
|
||||
scores["rouge-2"]["f"],
|
||||
scores["rouge-2"]["p"],
|
||||
scores["rouge-2"]["r"],
|
||||
scores["rouge-l"]["f"],
|
||||
scores["rouge-l"]["p"],
|
||||
scores["rouge-l"]["r"],
|
||||
)
|
||||
|
||||
|
||||
@@ -187,9 +186,7 @@ def build_data_iterator(args, tokenizer):
|
||||
dataset = load_and_cache_examples(args, tokenizer)
|
||||
sampler = SequentialSampler(dataset)
|
||||
collate_fn = lambda data: collate(data, tokenizer, block_size=512, device=args.device)
|
||||
iterator = DataLoader(
|
||||
dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
|
||||
)
|
||||
iterator = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,)
|
||||
|
||||
return iterator
|
||||
|
||||
@@ -210,14 +207,9 @@ def collate(data, tokenizer, block_size, device):
|
||||
names = [name for name, _, _ in data]
|
||||
summaries = [" ".join(summary_list) for _, _, summary_list in data]
|
||||
|
||||
encoded_text = [
|
||||
encode_for_summarization(story, summary, tokenizer) for _, story, summary in data
|
||||
]
|
||||
encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
|
||||
encoded_stories = torch.tensor(
|
||||
[
|
||||
fit_to_block_size(story, block_size, tokenizer.pad_token_id)
|
||||
for story, _ in encoded_text
|
||||
]
|
||||
[fit_to_block_size(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
|
||||
)
|
||||
encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
|
||||
encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
|
||||
@@ -272,38 +264,23 @@ def main():
|
||||
)
|
||||
# EVALUATION options
|
||||
parser.add_argument(
|
||||
"--no_cuda",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="Whether to force the execution on CPU.",
|
||||
"--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
# BEAM SEARCH arguments
|
||||
parser.add_argument(
|
||||
"--min_length",
|
||||
default=50,
|
||||
type=int,
|
||||
help="Minimum number of tokens for the summaries.",
|
||||
"--min_length", default=50, type=int, help="Minimum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
default=200,
|
||||
type=int,
|
||||
help="Maixmum number of tokens for the summaries.",
|
||||
"--max_length", default=200, type=int, help="Maixmum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beam_size",
|
||||
default=5,
|
||||
type=int,
|
||||
help="The number of beams to start with for each example.",
|
||||
"--beam_size", default=5, type=int, help="The number of beams to start with for each example.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
default=0.95,
|
||||
type=float,
|
||||
help="The value of alpha for the length penalty in the beam search.",
|
||||
"--alpha", default=0.95, type=float, help="The value of alpha for the length penalty in the beam search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_trigram",
|
||||
|
||||
@@ -68,9 +68,7 @@ def process_story(raw_story):
|
||||
Raises:
|
||||
IndexError: If the stoy is empty or contains no highlights.
|
||||
"""
|
||||
nonempty_lines = list(
|
||||
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
|
||||
)
|
||||
nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
|
||||
|
||||
# for some unknown reason some lines miss a period, add it
|
||||
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
|
||||
@@ -135,13 +133,9 @@ def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
sentences.
|
||||
"""
|
||||
story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
|
||||
story_token_ids = [
|
||||
token for sentence in story_lines_token_ids for token in sentence
|
||||
]
|
||||
story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
|
||||
summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
|
||||
summary_token_ids = [
|
||||
token for sentence in summary_lines_token_ids for token in sentence
|
||||
]
|
||||
summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]
|
||||
|
||||
return story_token_ids, summary_token_ids
|
||||
|
||||
|
||||
@@ -33,25 +33,19 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
|
||||
sequence = [1, 2, 3, 4]
|
||||
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
|
||||
self.assertEqual(
|
||||
fit_to_block_size(sequence, self.block_size, 0), expected_output
|
||||
)
|
||||
self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_fit_exactly(self):
|
||||
""" Do nothing if the sequence is the right size. """
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(
|
||||
fit_to_block_size(sequence, self.block_size, 0), expected_output
|
||||
)
|
||||
self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_too_big(self):
|
||||
""" Truncate the sequence if it is too long. """
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(
|
||||
fit_to_block_size(sequence, self.block_size, 0), expected_output
|
||||
)
|
||||
self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_process_story_no_highlights(self):
|
||||
""" Processing a story with no highlights returns an empty list for the summary.
|
||||
@@ -95,9 +89,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
def test_build_mask(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
|
||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||
np.testing.assert_array_equal(
|
||||
build_mask(sequence, 23).numpy(), expected.numpy()
|
||||
)
|
||||
np.testing.assert_array_equal(build_mask(sequence, 23).numpy(), expected.numpy())
|
||||
|
||||
def test_build_mask_with_padding_equal_to_one(self):
|
||||
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
|
||||
@@ -106,12 +98,8 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
|
||||
def test_compute_token_type_ids(self):
|
||||
separator = 101
|
||||
batch = torch.tensor(
|
||||
[[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
|
||||
)
|
||||
expected = torch.tensor(
|
||||
[[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]]
|
||||
)
|
||||
batch = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]])
|
||||
expected = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]])
|
||||
|
||||
result = compute_token_type_ids(batch, separator)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
@@ -35,34 +35,36 @@ logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def get_setup_file():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-f')
|
||||
parser.add_argument("-f")
|
||||
args = parser.parse_args()
|
||||
return args.f
|
||||
|
||||
class ExamplesTests(unittest.TestCase):
|
||||
|
||||
class ExamplesTests(unittest.TestCase):
|
||||
def test_run_glue(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
testargs = ["run_glue.py",
|
||||
"--data_dir=./examples/tests_samples/MRPC/",
|
||||
"--task_name=mrpc",
|
||||
"--do_train",
|
||||
"--do_eval",
|
||||
"--output_dir=./examples/tests_samples/temp_dir",
|
||||
"--per_gpu_train_batch_size=2",
|
||||
"--per_gpu_eval_batch_size=1",
|
||||
"--learning_rate=1e-4",
|
||||
"--max_steps=10",
|
||||
"--warmup_steps=2",
|
||||
"--overwrite_output_dir",
|
||||
"--seed=42"]
|
||||
model_type, model_name = ("--model_type=bert",
|
||||
"--model_name_or_path=bert-base-uncased")
|
||||
with patch.object(sys, 'argv', testargs + [model_type, model_name]):
|
||||
testargs = [
|
||||
"run_glue.py",
|
||||
"--data_dir=./examples/tests_samples/MRPC/",
|
||||
"--task_name=mrpc",
|
||||
"--do_train",
|
||||
"--do_eval",
|
||||
"--output_dir=./examples/tests_samples/temp_dir",
|
||||
"--per_gpu_train_batch_size=2",
|
||||
"--per_gpu_eval_batch_size=1",
|
||||
"--learning_rate=1e-4",
|
||||
"--max_steps=10",
|
||||
"--warmup_steps=2",
|
||||
"--overwrite_output_dir",
|
||||
"--seed=42",
|
||||
]
|
||||
model_type, model_name = ("--model_type=bert", "--model_name_or_path=bert-base-uncased")
|
||||
with patch.object(sys, "argv", testargs + [model_type, model_name]):
|
||||
result = run_glue.main()
|
||||
for value in result.values():
|
||||
self.assertGreaterEqual(value, 0.75)
|
||||
@@ -71,40 +73,38 @@ class ExamplesTests(unittest.TestCase):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
testargs = ["run_squad.py",
|
||||
"--data_dir=./examples/tests_samples/SQUAD",
|
||||
"--model_name=bert-base-uncased",
|
||||
"--output_dir=./examples/tests_samples/temp_dir",
|
||||
"--max_steps=10",
|
||||
"--warmup_steps=2",
|
||||
"--do_train",
|
||||
"--do_eval",
|
||||
"--version_2_with_negative",
|
||||
"--learning_rate=2e-4",
|
||||
"--per_gpu_train_batch_size=2",
|
||||
"--per_gpu_eval_batch_size=1",
|
||||
"--overwrite_output_dir",
|
||||
"--seed=42"]
|
||||
model_type, model_name = ("--model_type=bert",
|
||||
"--model_name_or_path=bert-base-uncased")
|
||||
with patch.object(sys, 'argv', testargs + [model_type, model_name]):
|
||||
testargs = [
|
||||
"run_squad.py",
|
||||
"--data_dir=./examples/tests_samples/SQUAD",
|
||||
"--model_name=bert-base-uncased",
|
||||
"--output_dir=./examples/tests_samples/temp_dir",
|
||||
"--max_steps=10",
|
||||
"--warmup_steps=2",
|
||||
"--do_train",
|
||||
"--do_eval",
|
||||
"--version_2_with_negative",
|
||||
"--learning_rate=2e-4",
|
||||
"--per_gpu_train_batch_size=2",
|
||||
"--per_gpu_eval_batch_size=1",
|
||||
"--overwrite_output_dir",
|
||||
"--seed=42",
|
||||
]
|
||||
model_type, model_name = ("--model_type=bert", "--model_name_or_path=bert-base-uncased")
|
||||
with patch.object(sys, "argv", testargs + [model_type, model_name]):
|
||||
result = run_squad.main()
|
||||
self.assertGreaterEqual(result['f1'], 30)
|
||||
self.assertGreaterEqual(result['exact'], 30)
|
||||
self.assertGreaterEqual(result["f1"], 30)
|
||||
self.assertGreaterEqual(result["exact"], 30)
|
||||
|
||||
def test_generation(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
testargs = ["run_generation.py",
|
||||
"--prompt=Hello",
|
||||
"--length=10",
|
||||
"--seed=42"]
|
||||
model_type, model_name = ("--model_type=openai-gpt",
|
||||
"--model_name_or_path=openai-gpt")
|
||||
with patch.object(sys, 'argv', testargs + [model_type, model_name]):
|
||||
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
|
||||
model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt")
|
||||
with patch.object(sys, "argv", testargs + [model_type, model_name]):
|
||||
result = run_generation.main()
|
||||
self.assertGreaterEqual(len(result), 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -55,19 +55,10 @@ class InputExample(object):
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
def __init__(self,
|
||||
example_id,
|
||||
choices_features,
|
||||
label
|
||||
|
||||
):
|
||||
def __init__(self, example_id, choices_features, label):
|
||||
self.example_id = example_id
|
||||
self.choices_features = [
|
||||
{
|
||||
'input_ids': input_ids,
|
||||
'input_mask': input_mask,
|
||||
'segment_ids': segment_ids
|
||||
}
|
||||
{"input_ids": input_ids, "input_mask": input_mask, "segment_ids": segment_ids}
|
||||
for input_ids, input_mask, segment_ids in choices_features
|
||||
]
|
||||
self.label = label
|
||||
@@ -99,29 +90,29 @@ class RaceProcessor(DataProcessor):
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} train".format(data_dir))
|
||||
high = os.path.join(data_dir, 'train/high')
|
||||
middle = os.path.join(data_dir, 'train/middle')
|
||||
high = os.path.join(data_dir, "train/high")
|
||||
middle = os.path.join(data_dir, "train/middle")
|
||||
high = self._read_txt(high)
|
||||
middle = self._read_txt(middle)
|
||||
return self._create_examples(high + middle, 'train')
|
||||
return self._create_examples(high + middle, "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
high = os.path.join(data_dir, 'dev/high')
|
||||
middle = os.path.join(data_dir, 'dev/middle')
|
||||
high = os.path.join(data_dir, "dev/high")
|
||||
middle = os.path.join(data_dir, "dev/middle")
|
||||
high = self._read_txt(high)
|
||||
middle = self._read_txt(middle)
|
||||
return self._create_examples(high + middle, 'dev')
|
||||
return self._create_examples(high + middle, "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} test".format(data_dir))
|
||||
high = os.path.join(data_dir, 'test/high')
|
||||
middle = os.path.join(data_dir, 'test/middle')
|
||||
high = os.path.join(data_dir, "test/high")
|
||||
middle = os.path.join(data_dir, "test/middle")
|
||||
high = self._read_txt(high)
|
||||
middle = self._read_txt(middle)
|
||||
return self._create_examples(high + middle, 'test')
|
||||
return self._create_examples(high + middle, "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
@@ -131,13 +122,12 @@ class RaceProcessor(DataProcessor):
|
||||
lines = []
|
||||
files = glob.glob(input_dir + "/*txt")
|
||||
for file in tqdm.tqdm(files, desc="read files"):
|
||||
with open(file, 'r', encoding='utf-8') as fin:
|
||||
with open(file, "r", encoding="utf-8") as fin:
|
||||
data_raw = json.load(fin)
|
||||
data_raw["race_id"] = file
|
||||
lines.append(data_raw)
|
||||
return lines
|
||||
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
@@ -145,19 +135,22 @@ class RaceProcessor(DataProcessor):
|
||||
race_id = "%s-%s" % (set_type, data_raw["race_id"])
|
||||
article = data_raw["article"]
|
||||
for i in range(len(data_raw["answers"])):
|
||||
truth = str(ord(data_raw['answers'][i]) - ord('A'))
|
||||
question = data_raw['questions'][i]
|
||||
options = data_raw['options'][i]
|
||||
truth = str(ord(data_raw["answers"][i]) - ord("A"))
|
||||
question = data_raw["questions"][i]
|
||||
options = data_raw["options"][i]
|
||||
|
||||
examples.append(
|
||||
InputExample(
|
||||
example_id=race_id,
|
||||
question=question,
|
||||
contexts=[article, article, article, article], # this is not efficient but convenient
|
||||
contexts=[article, article, article, article], # this is not efficient but convenient
|
||||
endings=[options[0], options[1], options[2], options[3]],
|
||||
label=truth))
|
||||
label=truth,
|
||||
)
|
||||
)
|
||||
return examples
|
||||
|
||||
|
||||
class SwagProcessor(DataProcessor):
|
||||
"""Processor for the SWAG data set."""
|
||||
|
||||
@@ -179,27 +172,25 @@ class SwagProcessor(DataProcessor):
|
||||
"setting!"
|
||||
)
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1", "2", "3"]
|
||||
|
||||
def _read_csv(self, input_file):
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
lines = []
|
||||
for line in reader:
|
||||
if sys.version_info[0] == 2:
|
||||
line = list(unicode(cell, 'utf-8') for cell in line)
|
||||
line = list(unicode(cell, "utf-8") for cell in line)
|
||||
lines.append(line)
|
||||
return lines
|
||||
|
||||
|
||||
def _create_examples(self, lines: List[List[str]], type: str):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
if type == "train" and lines[0][-1] != 'label':
|
||||
raise ValueError(
|
||||
"For training, the input file must contain a label column."
|
||||
)
|
||||
if type == "train" and lines[0][-1] != "label":
|
||||
raise ValueError("For training, the input file must contain a label column.")
|
||||
|
||||
examples = [
|
||||
InputExample(
|
||||
@@ -207,10 +198,11 @@ class SwagProcessor(DataProcessor):
|
||||
question=line[5], # in the swag dataset, the
|
||||
# common beginning of each
|
||||
# choice is stored in "sent2".
|
||||
contexts = [line[4], line[4], line[4], line[4]],
|
||||
endings = [line[7], line[8], line[9], line[10]],
|
||||
label=line[11]
|
||||
) for line in lines[1:] # we skip the line with the column names
|
||||
contexts=[line[4], line[4], line[4], line[4]],
|
||||
endings=[line[7], line[8], line[9], line[10]],
|
||||
label=line[11],
|
||||
)
|
||||
for line in lines[1:] # we skip the line with the column names
|
||||
]
|
||||
|
||||
return examples
|
||||
@@ -238,15 +230,14 @@ class ArcProcessor(DataProcessor):
|
||||
return ["0", "1", "2", "3"]
|
||||
|
||||
def _read_json(self, input_file):
|
||||
with open(input_file, 'r', encoding='utf-8') as fin:
|
||||
with open(input_file, "r", encoding="utf-8") as fin:
|
||||
lines = fin.readlines()
|
||||
return lines
|
||||
|
||||
|
||||
def _create_examples(self, lines, type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
|
||||
#There are two types of labels. They should be normalized
|
||||
# There are two types of labels. They should be normalized
|
||||
def normalize(truth):
|
||||
if truth in "ABCD":
|
||||
return ord(truth) - ord("A")
|
||||
@@ -283,12 +274,18 @@ class ArcProcessor(DataProcessor):
|
||||
if len(options) == 4:
|
||||
examples.append(
|
||||
InputExample(
|
||||
example_id = id,
|
||||
example_id=id,
|
||||
question=question,
|
||||
contexts=[options[0]["para"].replace("_", ""), options[1]["para"].replace("_", ""),
|
||||
options[2]["para"].replace("_", ""), options[3]["para"].replace("_", "")],
|
||||
contexts=[
|
||||
options[0]["para"].replace("_", ""),
|
||||
options[1]["para"].replace("_", ""),
|
||||
options[2]["para"].replace("_", ""),
|
||||
options[3]["para"].replace("_", ""),
|
||||
],
|
||||
endings=[options[0]["text"], options[1]["text"], options[2]["text"], options[3]["text"]],
|
||||
label=truth))
|
||||
label=truth,
|
||||
)
|
||||
)
|
||||
|
||||
if type == "train":
|
||||
assert len(examples) > 1
|
||||
@@ -316,7 +313,7 @@ def convert_examples_to_features(
|
||||
Loads a data file into a list of `InputFeatures`
|
||||
"""
|
||||
|
||||
label_map = {label : i for i, label in enumerate(label_list)}
|
||||
label_map = {label: i for i, label in enumerate(label_list)}
|
||||
|
||||
features = []
|
||||
for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
|
||||
@@ -331,16 +328,13 @@ def convert_examples_to_features(
|
||||
else:
|
||||
text_b = example.question + " " + ending
|
||||
|
||||
inputs = tokenizer.encode_plus(
|
||||
text_a,
|
||||
text_b,
|
||||
add_special_tokens=True,
|
||||
max_length=max_length,
|
||||
)
|
||||
if 'num_truncated_tokens' in inputs and inputs['num_truncated_tokens'] > 0:
|
||||
logger.info('Attention! you are cropping tokens (swag task is ok). '
|
||||
'If you are training ARC and RACE and you are poping question + options,'
|
||||
'you need to try to use a bigger max seq length!')
|
||||
inputs = tokenizer.encode_plus(text_a, text_b, add_special_tokens=True, max_length=max_length,)
|
||||
if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
|
||||
logger.info(
|
||||
"Attention! you are cropping tokens (swag task is ok). "
|
||||
"If you are training ARC and RACE and you are poping question + options,"
|
||||
"you need to try to use a bigger max seq length!"
|
||||
)
|
||||
|
||||
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
|
||||
|
||||
@@ -364,7 +358,6 @@ def convert_examples_to_features(
|
||||
assert len(token_type_ids) == max_length
|
||||
choices_features.append((input_ids, attention_mask, token_type_ids))
|
||||
|
||||
|
||||
label = label_map[example.label]
|
||||
|
||||
if ex_index < 2:
|
||||
@@ -372,33 +365,17 @@ def convert_examples_to_features(
|
||||
logger.info("race_id: {}".format(example.example_id))
|
||||
for choice_idx, (input_ids, attention_mask, token_type_ids) in enumerate(choices_features):
|
||||
logger.info("choice: {}".format(choice_idx))
|
||||
logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
|
||||
logger.info("attention_mask: {}".format(' '.join(map(str, attention_mask))))
|
||||
logger.info("token_type_ids: {}".format(' '.join(map(str, token_type_ids))))
|
||||
logger.info("input_ids: {}".format(" ".join(map(str, input_ids))))
|
||||
logger.info("attention_mask: {}".format(" ".join(map(str, attention_mask))))
|
||||
logger.info("token_type_ids: {}".format(" ".join(map(str, token_type_ids))))
|
||||
logger.info("label: {}".format(label))
|
||||
|
||||
features.append(
|
||||
InputFeatures(
|
||||
example_id=example.example_id,
|
||||
choices_features=choices_features,
|
||||
label=label,
|
||||
)
|
||||
)
|
||||
features.append(InputFeatures(example_id=example.example_id, choices_features=choices_features, label=label,))
|
||||
|
||||
return features
|
||||
|
||||
|
||||
processors = {"race": RaceProcessor, "swag": SwagProcessor, "arc": ArcProcessor}
|
||||
|
||||
|
||||
processors = {
|
||||
"race": RaceProcessor,
|
||||
"swag": SwagProcessor,
|
||||
"arc": ArcProcessor
|
||||
}
|
||||
|
||||
|
||||
MULTIPLE_CHOICE_TASKS_NUM_LABELS = {
|
||||
"race", 4,
|
||||
"swag", 4,
|
||||
"arc", 4
|
||||
}
|
||||
MULTIPLE_CHOICE_TASKS_NUM_LABELS = {"race", 4, "swag", 4, "arc", 4}
|
||||
|
||||
@@ -61,9 +61,7 @@ def read_examples_from_file(data_dir, mode):
|
||||
for line in f:
|
||||
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
||||
if words:
|
||||
examples.append(InputExample(guid="{}-{}".format(mode, guid_index),
|
||||
words=words,
|
||||
labels=labels))
|
||||
examples.append(InputExample(guid="{}-{}".format(mode, guid_index), words=words, labels=labels))
|
||||
guid_index += 1
|
||||
words = []
|
||||
labels = []
|
||||
@@ -76,27 +74,27 @@ def read_examples_from_file(data_dir, mode):
|
||||
# Examples could have no label for mode = "test"
|
||||
labels.append("O")
|
||||
if words:
|
||||
examples.append(InputExample(guid="%s-%d".format(mode, guid_index),
|
||||
words=words,
|
||||
labels=labels))
|
||||
examples.append(InputExample(guid="%s-%d".format(mode, guid_index), words=words, labels=labels))
|
||||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(examples,
|
||||
label_list,
|
||||
max_seq_length,
|
||||
tokenizer,
|
||||
cls_token_at_end=False,
|
||||
cls_token="[CLS]",
|
||||
cls_token_segment_id=1,
|
||||
sep_token="[SEP]",
|
||||
sep_token_extra=False,
|
||||
pad_on_left=False,
|
||||
pad_token=0,
|
||||
pad_token_segment_id=0,
|
||||
pad_token_label_id=-100,
|
||||
sequence_a_segment_id=0,
|
||||
mask_padding_with_zero=True):
|
||||
def convert_examples_to_features(
|
||||
examples,
|
||||
label_list,
|
||||
max_seq_length,
|
||||
tokenizer,
|
||||
cls_token_at_end=False,
|
||||
cls_token="[CLS]",
|
||||
cls_token_segment_id=1,
|
||||
sep_token="[SEP]",
|
||||
sep_token_extra=False,
|
||||
pad_on_left=False,
|
||||
pad_token=0,
|
||||
pad_token_segment_id=0,
|
||||
pad_token_label_id=-100,
|
||||
sequence_a_segment_id=0,
|
||||
mask_padding_with_zero=True,
|
||||
):
|
||||
""" Loads a data file into a list of `InputBatch`s
|
||||
`cls_token_at_end` define the location of the CLS token:
|
||||
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
|
||||
@@ -122,8 +120,8 @@ def convert_examples_to_features(examples,
|
||||
# Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
|
||||
special_tokens_count = 3 if sep_token_extra else 2
|
||||
if len(tokens) > max_seq_length - special_tokens_count:
|
||||
tokens = tokens[:(max_seq_length - special_tokens_count)]
|
||||
label_ids = label_ids[:(max_seq_length - special_tokens_count)]
|
||||
tokens = tokens[: (max_seq_length - special_tokens_count)]
|
||||
label_ids = label_ids[: (max_seq_length - special_tokens_count)]
|
||||
|
||||
# The convention in BERT is:
|
||||
# (a) For sequence pairs:
|
||||
@@ -174,10 +172,10 @@ def convert_examples_to_features(examples,
|
||||
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
|
||||
label_ids = ([pad_token_label_id] * padding_length) + label_ids
|
||||
else:
|
||||
input_ids += ([pad_token] * padding_length)
|
||||
input_mask += ([0 if mask_padding_with_zero else 1] * padding_length)
|
||||
segment_ids += ([pad_token_segment_id] * padding_length)
|
||||
label_ids += ([pad_token_label_id] * padding_length)
|
||||
input_ids += [pad_token] * padding_length
|
||||
input_mask += [0 if mask_padding_with_zero else 1] * padding_length
|
||||
segment_ids += [pad_token_segment_id] * padding_length
|
||||
label_ids += [pad_token_label_id] * padding_length
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
@@ -194,10 +192,8 @@ def convert_examples_to_features(examples,
|
||||
logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))
|
||||
|
||||
features.append(
|
||||
InputFeatures(input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
segment_ids=segment_ids,
|
||||
label_ids=label_ids))
|
||||
InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids)
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
@@ -209,4 +205,4 @@ def get_labels(path):
|
||||
labels = ["O"] + labels
|
||||
return labels
|
||||
else:
|
||||
return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
|
||||
return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
|
||||
|
||||
11
hubconf.py
11
hubconf.py
@@ -1,9 +1,15 @@
|
||||
from transformers import (
|
||||
AutoTokenizer, AutoConfig, AutoModel, AutoModelWithLMHead, AutoModelForSequenceClassification, AutoModelForQuestionAnswering
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelWithLMHead,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForQuestionAnswering,
|
||||
)
|
||||
from transformers.file_utils import add_start_docstrings
|
||||
|
||||
dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex', 'sentencepiece', 'sacremoses']
|
||||
dependencies = ["torch", "tqdm", "boto3", "requests", "regex", "sentencepiece", "sacremoses"]
|
||||
|
||||
|
||||
@add_start_docstrings(AutoConfig.__doc__)
|
||||
def config(*args, **kwargs):
|
||||
@@ -57,6 +63,7 @@ def model(*args, **kwargs):
|
||||
|
||||
return AutoModel.from_pretrained(*args, **kwargs)
|
||||
|
||||
|
||||
@add_start_docstrings(AutoModelWithLMHead.__doc__)
|
||||
def modelWithLMHead(*args, **kwargs):
|
||||
r"""
|
||||
|
||||
47
setup.py
47
setup.py
@@ -38,11 +38,11 @@ from setuptools import find_packages, setup
|
||||
|
||||
|
||||
extras = {
|
||||
'serving': ['pydantic', 'uvicorn', 'fastapi'],
|
||||
'serving-tf': ['pydantic', 'uvicorn', 'fastapi', 'tensorflow'],
|
||||
'serving-torch': ['pydantic', 'uvicorn', 'fastapi', 'torch']
|
||||
"serving": ["pydantic", "uvicorn", "fastapi"],
|
||||
"serving-tf": ["pydantic", "uvicorn", "fastapi", "tensorflow"],
|
||||
"serving-torch": ["pydantic", "uvicorn", "fastapi", "torch"],
|
||||
}
|
||||
extras['all'] = [package for package in extras.values()]
|
||||
extras["all"] = [package for package in extras.values()]
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
@@ -50,30 +50,29 @@ setup(
|
||||
author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
|
||||
author_email="thomas@huggingface.co",
|
||||
description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",
|
||||
long_description=open("README.md", "r", encoding='utf-8').read(),
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
keywords='NLP deep learning transformer pytorch tensorflow BERT GPT GPT-2 google openai CMU',
|
||||
license='Apache',
|
||||
keywords="NLP deep learning transformer pytorch tensorflow BERT GPT GPT-2 google openai CMU",
|
||||
license="Apache",
|
||||
url="https://github.com/huggingface/transformers",
|
||||
packages=find_packages(exclude=["*.tests", "*.tests.*",
|
||||
"tests.*", "tests"]),
|
||||
install_requires=['numpy',
|
||||
'boto3',
|
||||
'filelock',
|
||||
'requests',
|
||||
'tqdm',
|
||||
'regex != 2019.12.17',
|
||||
'sentencepiece',
|
||||
'sacremoses'],
|
||||
extras_require=extras,
|
||||
scripts=[
|
||||
'transformers-cli'
|
||||
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"boto3",
|
||||
"filelock",
|
||||
"requests",
|
||||
"tqdm",
|
||||
"regex != 2019.12.17",
|
||||
"sentencepiece",
|
||||
"sacremoses",
|
||||
],
|
||||
extras_require=extras,
|
||||
scripts=["transformers-cli"],
|
||||
# python_requires='>=3.5.0',
|
||||
classifiers=[
|
||||
'Intended Audience :: Science/Research',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -24,8 +24,7 @@ import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
try:
|
||||
@@ -35,19 +34,32 @@ except:
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForQuestionAnswering, BertTokenizer,
|
||||
XLMConfig, XLMForQuestionAnswering,
|
||||
XLMTokenizer, XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer,
|
||||
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
BertConfig,
|
||||
BertForQuestionAnswering,
|
||||
BertTokenizer,
|
||||
XLMConfig,
|
||||
XLMForQuestionAnswering,
|
||||
XLMTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForQuestionAnswering,
|
||||
DistilBertTokenizer,
|
||||
)
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
from utils_squad import (read_squad_examples, convert_examples_to_features,
|
||||
RawResult, write_predictions,
|
||||
RawResultExtended, write_predictions_extended)
|
||||
from utils_squad import (
|
||||
read_squad_examples,
|
||||
convert_examples_to_features,
|
||||
RawResult,
|
||||
write_predictions,
|
||||
RawResultExtended,
|
||||
write_predictions_extended,
|
||||
)
|
||||
|
||||
# The follwing import is the official SQuAD evaluation script (2.0).
|
||||
# You can remove it from the dependencies if you are using this script outside of the library
|
||||
@@ -56,16 +68,18 @@ from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
||||
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
||||
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
@@ -73,9 +87,11 @@ def set_seed(args):
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def to_list(tensor):
|
||||
return tensor.detach().cpu().tolist()
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
@@ -92,13 +108,18 @@ def train(args, train_dataset, model, tokenizer):
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
@@ -112,17 +133,21 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
* args.gradient_accumulation_steps
|
||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
@@ -136,20 +161,21 @@ def train(args, train_dataset, model, tokenizer):
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'start_positions': batch[3],
|
||||
'end_positions': batch[4]}
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[5],
|
||||
'p_mask': batch[6]})
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"start_positions": batch[3],
|
||||
"end_positions": batch[4],
|
||||
}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
@@ -173,22 +199,26 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
if (
|
||||
args.local_rank == -1 and args.evaluate_during_training
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
@@ -224,32 +254,31 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1]
|
||||
}
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2] # XLM don't use segment_ids
|
||||
example_indices = batch[3]
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[4],
|
||||
'p_mask': batch[5]})
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
|
||||
outputs = model(**inputs)
|
||||
|
||||
for i, example_index in enumerate(example_indices):
|
||||
eval_feature = features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
result = RawResultExtended(unique_id = unique_id,
|
||||
start_top_log_probs = to_list(outputs[0][i]),
|
||||
start_top_index = to_list(outputs[1][i]),
|
||||
end_top_log_probs = to_list(outputs[2][i]),
|
||||
end_top_index = to_list(outputs[3][i]),
|
||||
cls_logits = to_list(outputs[4][i]))
|
||||
result = RawResultExtended(
|
||||
unique_id=unique_id,
|
||||
start_top_log_probs=to_list(outputs[0][i]),
|
||||
start_top_index=to_list(outputs[1][i]),
|
||||
end_top_log_probs=to_list(outputs[2][i]),
|
||||
end_top_index=to_list(outputs[3][i]),
|
||||
cls_logits=to_list(outputs[4][i]),
|
||||
)
|
||||
else:
|
||||
result = RawResult(unique_id = unique_id,
|
||||
start_logits = to_list(outputs[0][i]),
|
||||
end_logits = to_list(outputs[1][i]))
|
||||
result = RawResult(
|
||||
unique_id=unique_id, start_logits=to_list(outputs[0][i]), end_logits=to_list(outputs[1][i])
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
# Compute predictions
|
||||
@@ -260,23 +289,44 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
else:
|
||||
output_null_log_odds_file = None
|
||||
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
||||
model.config.start_n_top, model.config.end_n_top,
|
||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
||||
write_predictions_extended(
|
||||
examples,
|
||||
features,
|
||||
all_results,
|
||||
args.n_best_size,
|
||||
args.max_answer_length,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
args.predict_file,
|
||||
model.config.start_n_top,
|
||||
model.config.end_n_top,
|
||||
args.version_2_with_negative,
|
||||
tokenizer,
|
||||
args.verbose_logging,
|
||||
)
|
||||
else:
|
||||
write_predictions(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||
write_predictions(
|
||||
examples,
|
||||
features,
|
||||
all_results,
|
||||
args.n_best_size,
|
||||
args.max_answer_length,
|
||||
args.do_lower_case,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
args.verbose_logging,
|
||||
args.version_2_with_negative,
|
||||
args.null_score_diff_threshold,
|
||||
)
|
||||
|
||||
# Evaluate with the official SQuAD script
|
||||
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
|
||||
pred_file=output_prediction_file,
|
||||
na_prob_file=output_null_log_odds_file)
|
||||
evaluate_options = EVAL_OPTS(
|
||||
data_file=args.predict_file, pred_file=output_prediction_file, na_prob_file=output_null_log_odds_file
|
||||
)
|
||||
results = evaluate_on_squad(evaluate_options)
|
||||
return results
|
||||
|
||||
@@ -287,24 +337,30 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
input_file = args.predict_file if evaluate else args.train_file
|
||||
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
|
||||
'dev' if evaluate else 'train',
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length)))
|
||||
cached_features_file = os.path.join(
|
||||
os.path.dirname(input_file),
|
||||
"cached_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", input_file)
|
||||
examples = read_squad_examples(input_file=input_file,
|
||||
is_training=not evaluate,
|
||||
version_2_with_negative=args.version_2_with_negative)
|
||||
features = convert_examples_to_features(examples=examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate)
|
||||
examples = read_squad_examples(
|
||||
input_file=input_file, is_training=not evaluate, version_2_with_negative=args.version_2_with_negative
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
examples=examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
@@ -320,14 +376,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||
if evaluate:
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_example_index, all_cls_index, all_p_mask)
|
||||
dataset = TensorDataset(
|
||||
all_input_ids, all_input_mask, all_segment_ids, all_example_index, all_cls_index, all_p_mask
|
||||
)
|
||||
else:
|
||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_start_positions, all_end_positions,
|
||||
all_cls_index, all_p_mask)
|
||||
dataset = TensorDataset(
|
||||
all_input_ids,
|
||||
all_input_mask,
|
||||
all_segment_ids,
|
||||
all_start_positions,
|
||||
all_end_positions,
|
||||
all_cls_index,
|
||||
all_p_mask,
|
||||
)
|
||||
|
||||
if output_examples:
|
||||
return dataset, examples, features
|
||||
@@ -338,109 +401,190 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--train_file", default=None, type=str, required=True,
|
||||
help="SQuAD json for training. E.g., train-v1.1.json")
|
||||
parser.add_argument("--predict_file", default=None, type=str, required=True,
|
||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.")
|
||||
parser.add_argument(
|
||||
"--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predict_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
|
||||
parser.add_argument('--version_2_with_negative', action='store_true',
|
||||
help='If true, the SQuAD examples contain some that do not have an answer.')
|
||||
parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
|
||||
help="If null_score - best_non_null is greater than the threshold predict null.")
|
||||
parser.add_argument(
|
||||
"--version_2_with_negative",
|
||||
action="store_true",
|
||||
help="If true, the SQuAD examples contain some that do not have an answer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--null_score_diff_threshold",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="If null_score - best_non_null is greater than the threshold predict null.",
|
||||
)
|
||||
|
||||
parser.add_argument("--max_seq_length", default=384, type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded.")
|
||||
parser.add_argument("--doc_stride", default=128, type=int,
|
||||
help="When splitting up a long document into chunks, how much stride to take between chunks.")
|
||||
parser.add_argument("--max_query_length", default=64, type=int,
|
||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||
help="Rul evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=384,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--doc_stride",
|
||||
default=128,
|
||||
type=int,
|
||||
help="When splitting up a long document into chunks, how much stride to take between chunks.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_query_length",
|
||||
default=64,
|
||||
type=int,
|
||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--n_best_size", default=20, type=int,
|
||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
|
||||
parser.add_argument("--max_answer_length", default=30, type=int,
|
||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another.")
|
||||
parser.add_argument("--verbose_logging", action='store_true',
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument(
|
||||
"--n_best_size",
|
||||
default=20,
|
||||
type=int,
|
||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_answer_length",
|
||||
default=30,
|
||||
type=int,
|
||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose_logging",
|
||||
action="store_true",
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.",
|
||||
)
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Whether not to use CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="local_rank for distributed training on gpus")
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
@@ -452,16 +596,24 @@ def main():
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank,
|
||||
device,
|
||||
args.n_gpu,
|
||||
bool(args.local_rank != -1),
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
@@ -472,15 +624,21 @@ def main():
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model = model_class.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
@@ -495,7 +653,8 @@ def main():
|
||||
if args.fp16:
|
||||
try:
|
||||
import apex
|
||||
apex.amp.register_half_function(torch, 'einsum')
|
||||
|
||||
apex.amp.register_half_function(torch, "einsum")
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
|
||||
@@ -505,7 +664,6 @@ def main():
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
# Save the trained model and the tokenizer
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
@@ -515,39 +673,42 @@ def main():
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
checkpoints = list(
|
||||
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||
)
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluate
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
|
||||
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
|
||||
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
logger.info("Results: {}".format(results))
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2018 XXX. All rights reserved.
|
||||
#
|
||||
@@ -37,14 +36,16 @@ class SquadExample(object):
|
||||
For examples without an answer, the start and end position are -1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
qas_id,
|
||||
question_text,
|
||||
doc_tokens,
|
||||
orig_answer_text=None,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=None):
|
||||
def __init__(
|
||||
self,
|
||||
qas_id,
|
||||
question_text,
|
||||
doc_tokens,
|
||||
orig_answer_text=None,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=None,
|
||||
):
|
||||
self.qas_id = qas_id
|
||||
self.question_text = question_text
|
||||
self.doc_tokens = doc_tokens
|
||||
@@ -59,8 +60,7 @@ class SquadExample(object):
|
||||
def __repr__(self):
|
||||
s = ""
|
||||
s += "qas_id: %s" % (self.qas_id)
|
||||
s += ", question_text: %s" % (
|
||||
self.question_text)
|
||||
s += ", question_text: %s" % (self.question_text)
|
||||
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
|
||||
if self.start_position:
|
||||
s += ", start_position: %d" % (self.start_position)
|
||||
@@ -74,22 +74,24 @@ class SquadExample(object):
|
||||
class InputFeatures(object):
|
||||
"""A single set of features of data."""
|
||||
|
||||
def __init__(self,
|
||||
unique_id,
|
||||
example_index,
|
||||
doc_span_index,
|
||||
tokens,
|
||||
token_to_orig_map,
|
||||
token_is_max_context,
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
cls_index,
|
||||
p_mask,
|
||||
paragraph_len,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=None):
|
||||
def __init__(
|
||||
self,
|
||||
unique_id,
|
||||
example_index,
|
||||
doc_span_index,
|
||||
tokens,
|
||||
token_to_orig_map,
|
||||
token_is_max_context,
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
cls_index,
|
||||
p_mask,
|
||||
paragraph_len,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=None,
|
||||
):
|
||||
self.unique_id = unique_id
|
||||
self.example_index = example_index
|
||||
self.doc_span_index = doc_span_index
|
||||
@@ -109,7 +111,7 @@ class InputFeatures(object):
|
||||
|
||||
def read_squad_examples(input_file, is_training, version_2_with_negative):
|
||||
"""Read a SQuAD json file into a list of SquadExample."""
|
||||
with open(input_file, "r", encoding='utf-8') as reader:
|
||||
with open(input_file, "r", encoding="utf-8") as reader:
|
||||
input_data = json.load(reader)["data"]
|
||||
|
||||
def is_whitespace(c):
|
||||
@@ -146,8 +148,7 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
|
||||
if version_2_with_negative:
|
||||
is_impossible = qa["is_impossible"]
|
||||
if (len(qa["answers"]) != 1) and (not is_impossible):
|
||||
raise ValueError(
|
||||
"For training, each question should have exactly 1 answer.")
|
||||
raise ValueError("For training, each question should have exactly 1 answer.")
|
||||
if not is_impossible:
|
||||
answer = qa["answers"][0]
|
||||
orig_answer_text = answer["text"]
|
||||
@@ -161,12 +162,10 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
|
||||
#
|
||||
# Note that this means for training mode, every example is NOT
|
||||
# guaranteed to be preserved.
|
||||
actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
|
||||
cleaned_answer_text = " ".join(
|
||||
whitespace_tokenize(orig_answer_text))
|
||||
actual_text = " ".join(doc_tokens[start_position : (end_position + 1)])
|
||||
cleaned_answer_text = " ".join(whitespace_tokenize(orig_answer_text))
|
||||
if actual_text.find(cleaned_answer_text) == -1:
|
||||
logger.warning("Could not find answer: '%s' vs. '%s'",
|
||||
actual_text, cleaned_answer_text)
|
||||
logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
|
||||
continue
|
||||
else:
|
||||
start_position = -1
|
||||
@@ -180,18 +179,29 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
|
||||
orig_answer_text=orig_answer_text,
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=is_impossible)
|
||||
is_impossible=is_impossible,
|
||||
)
|
||||
examples.append(example)
|
||||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
doc_stride, max_query_length, is_training,
|
||||
cls_token_at_end=False,
|
||||
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
||||
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
||||
cls_token_segment_id=0, pad_token_segment_id=0,
|
||||
mask_padding_with_zero=True):
|
||||
def convert_examples_to_features(
|
||||
examples,
|
||||
tokenizer,
|
||||
max_seq_length,
|
||||
doc_stride,
|
||||
max_query_length,
|
||||
is_training,
|
||||
cls_token_at_end=False,
|
||||
cls_token="[CLS]",
|
||||
sep_token="[SEP]",
|
||||
pad_token=0,
|
||||
sequence_a_segment_id=0,
|
||||
sequence_b_segment_id=1,
|
||||
cls_token_segment_id=0,
|
||||
pad_token_segment_id=0,
|
||||
mask_padding_with_zero=True,
|
||||
):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
unique_id = 1000000000
|
||||
@@ -232,8 +242,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
else:
|
||||
tok_end_position = len(all_doc_tokens) - 1
|
||||
(tok_start_position, tok_end_position) = _improve_answer_span(
|
||||
all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
|
||||
example.orig_answer_text)
|
||||
all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.orig_answer_text
|
||||
)
|
||||
|
||||
# The -3 accounts for [CLS], [SEP] and [SEP]
|
||||
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
|
||||
@@ -241,8 +251,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
# We can have documents that are longer than the maximum sequence length.
|
||||
# To deal with this we do a sliding window approach, where we take chunks
|
||||
# of the up to our max length with a stride of `doc_stride`.
|
||||
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"DocSpan", ["start", "length"])
|
||||
_DocSpan = collections.namedtuple("DocSpan", ["start", "length"]) # pylint: disable=invalid-name
|
||||
doc_spans = []
|
||||
start_offset = 0
|
||||
while start_offset < len(all_doc_tokens):
|
||||
@@ -287,8 +296,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
split_token_index = doc_span.start + i
|
||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||
|
||||
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
|
||||
split_token_index)
|
||||
is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index)
|
||||
token_is_max_context[len(tokens)] = is_max_context
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
segment_ids.append(sequence_b_segment_id)
|
||||
@@ -333,8 +341,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
doc_start = doc_span.start
|
||||
doc_end = doc_span.start + doc_span.length - 1
|
||||
out_of_span = False
|
||||
if not (tok_start_position >= doc_start and
|
||||
tok_end_position <= doc_end):
|
||||
if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
|
||||
out_of_span = True
|
||||
if out_of_span:
|
||||
start_position = 0
|
||||
@@ -355,24 +362,23 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
logger.info("example_index: %s" % (example_index))
|
||||
logger.info("doc_span_index: %s" % (doc_span_index))
|
||||
logger.info("tokens: %s" % " ".join(tokens))
|
||||
logger.info("token_to_orig_map: %s" % " ".join([
|
||||
"%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()]))
|
||||
logger.info("token_is_max_context: %s" % " ".join([
|
||||
"%d:%s" % (x, y) for (x, y) in token_is_max_context.items()
|
||||
]))
|
||||
logger.info(
|
||||
"token_to_orig_map: %s" % " ".join(["%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()])
|
||||
)
|
||||
logger.info(
|
||||
"token_is_max_context: %s"
|
||||
% " ".join(["%d:%s" % (x, y) for (x, y) in token_is_max_context.items()])
|
||||
)
|
||||
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
||||
logger.info(
|
||||
"input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||
logger.info(
|
||||
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||
logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||
if is_training and span_is_impossible:
|
||||
logger.info("impossible example")
|
||||
if is_training and not span_is_impossible:
|
||||
answer_text = " ".join(tokens[start_position:(end_position + 1)])
|
||||
answer_text = " ".join(tokens[start_position : (end_position + 1)])
|
||||
logger.info("start_position: %d" % (start_position))
|
||||
logger.info("end_position: %d" % (end_position))
|
||||
logger.info(
|
||||
"answer: %s" % (answer_text))
|
||||
logger.info("answer: %s" % (answer_text))
|
||||
|
||||
features.append(
|
||||
InputFeatures(
|
||||
@@ -390,14 +396,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
paragraph_len=paragraph_len,
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=span_is_impossible))
|
||||
is_impossible=span_is_impossible,
|
||||
)
|
||||
)
|
||||
unique_id += 1
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
|
||||
orig_answer_text):
|
||||
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
|
||||
"""Returns tokenized answer spans that better match the annotated answer."""
|
||||
|
||||
# The SQuAD annotations are character based. We first project them to
|
||||
@@ -426,7 +433,7 @@ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
|
||||
|
||||
for new_start in range(input_start, input_end + 1):
|
||||
for new_end in range(input_end, new_start - 1, -1):
|
||||
text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
|
||||
text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
|
||||
if text_span == tok_answer_text:
|
||||
return (new_start, new_end)
|
||||
|
||||
@@ -470,13 +477,23 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
|
||||
return cur_span_index == best_span_index
|
||||
|
||||
|
||||
RawResult = collections.namedtuple("RawResult",
|
||||
["unique_id", "start_logits", "end_logits"])
|
||||
RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])
|
||||
|
||||
def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||
max_answer_length, do_lower_case, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, verbose_logging,
|
||||
version_2_with_negative, null_score_diff_threshold):
|
||||
|
||||
def write_predictions(
|
||||
all_examples,
|
||||
all_features,
|
||||
all_results,
|
||||
n_best_size,
|
||||
max_answer_length,
|
||||
do_lower_case,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
verbose_logging,
|
||||
version_2_with_negative,
|
||||
null_score_diff_threshold,
|
||||
):
|
||||
"""Write final predictions to the json file and log-odds of null if needed."""
|
||||
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
||||
logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||
@@ -490,8 +507,8 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||
unique_id_to_result[result.unique_id] = result
|
||||
|
||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"PrelimPrediction",
|
||||
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
|
||||
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
|
||||
)
|
||||
|
||||
all_predictions = collections.OrderedDict()
|
||||
all_nbest_json = collections.OrderedDict()
|
||||
@@ -544,7 +561,9 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
start_logit=result.start_logits[start_index],
|
||||
end_logit=result.end_logits[end_index]))
|
||||
end_logit=result.end_logits[end_index],
|
||||
)
|
||||
)
|
||||
if version_2_with_negative:
|
||||
prelim_predictions.append(
|
||||
_PrelimPrediction(
|
||||
@@ -552,14 +571,14 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||
start_index=0,
|
||||
end_index=0,
|
||||
start_logit=null_start_logit,
|
||||
end_logit=null_end_logit))
|
||||
prelim_predictions = sorted(
|
||||
prelim_predictions,
|
||||
key=lambda x: (x.start_logit + x.end_logit),
|
||||
reverse=True)
|
||||
end_logit=null_end_logit,
|
||||
)
|
||||
)
|
||||
prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
|
||||
|
||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"NbestPrediction", ["text", "start_logit", "end_logit"])
|
||||
"NbestPrediction", ["text", "start_logit", "end_logit"]
|
||||
)
|
||||
|
||||
seen_predictions = {}
|
||||
nbest = []
|
||||
@@ -568,10 +587,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||
break
|
||||
feature = features[pred.feature_index]
|
||||
if pred.start_index > 0: # this is a non-null prediction
|
||||
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
|
||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
|
||||
tok_text = " ".join(tok_tokens)
|
||||
|
||||
# De-tokenize WordPieces that have been split off.
|
||||
@@ -592,31 +611,21 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||
final_text = ""
|
||||
seen_predictions[final_text] = True
|
||||
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text=final_text,
|
||||
start_logit=pred.start_logit,
|
||||
end_logit=pred.end_logit))
|
||||
nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
|
||||
# if we didn't include the empty option in the n-best, include it
|
||||
if version_2_with_negative:
|
||||
if "" not in seen_predictions:
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text="",
|
||||
start_logit=null_start_logit,
|
||||
end_logit=null_end_logit))
|
||||
nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
|
||||
|
||||
# In very rare edge cases we could only have single null prediction.
|
||||
# So we just create a nonce prediction in this case to avoid failure.
|
||||
if len(nbest)==1:
|
||||
nbest.insert(0,
|
||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
if len(nbest) == 1:
|
||||
nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
|
||||
# In very rare edge cases we could have no valid predictions. So we
|
||||
# just create a nonce prediction in this case to avoid failure.
|
||||
if not nbest:
|
||||
nbest.append(
|
||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
|
||||
assert len(nbest) >= 1
|
||||
|
||||
@@ -645,8 +654,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
||||
else:
|
||||
# predict "" iff the null score - the score of best non-null > threshold
|
||||
score_diff = score_null - best_non_null_entry.start_logit - (
|
||||
best_non_null_entry.end_logit)
|
||||
score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
|
||||
scores_diff_json[example.qas_id] = score_diff
|
||||
if score_diff > null_score_diff_threshold:
|
||||
all_predictions[example.qas_id] = ""
|
||||
@@ -668,29 +676,40 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||
|
||||
|
||||
# For XLNet (and XLM which uses the same head)
|
||||
RawResultExtended = collections.namedtuple("RawResultExtended",
|
||||
["unique_id", "start_top_log_probs", "start_top_index",
|
||||
"end_top_log_probs", "end_top_index", "cls_logits"])
|
||||
RawResultExtended = collections.namedtuple(
|
||||
"RawResultExtended",
|
||||
["unique_id", "start_top_log_probs", "start_top_index", "end_top_log_probs", "end_top_index", "cls_logits"],
|
||||
)
|
||||
|
||||
|
||||
def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
|
||||
max_answer_length, output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file, orig_data_file,
|
||||
start_n_top, end_n_top, version_2_with_negative,
|
||||
tokenizer, verbose_logging):
|
||||
def write_predictions_extended(
|
||||
all_examples,
|
||||
all_features,
|
||||
all_results,
|
||||
n_best_size,
|
||||
max_answer_length,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
orig_data_file,
|
||||
start_n_top,
|
||||
end_n_top,
|
||||
version_2_with_negative,
|
||||
tokenizer,
|
||||
verbose_logging,
|
||||
):
|
||||
""" XLNet write prediction logic (more complex than Bert's).
|
||||
Write final predictions to the json file and log-odds of null if needed.
|
||||
|
||||
Requires utils_squad_evaluate.py
|
||||
"""
|
||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"PrelimPrediction",
|
||||
["feature_index", "start_index", "end_index",
|
||||
"start_log_prob", "end_log_prob"])
|
||||
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
|
||||
)
|
||||
|
||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
|
||||
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
|
||||
)
|
||||
|
||||
logger.info("Writing predictions to: %s", output_prediction_file)
|
||||
# logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||
@@ -754,12 +773,13 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
start_log_prob=start_log_prob,
|
||||
end_log_prob=end_log_prob))
|
||||
end_log_prob=end_log_prob,
|
||||
)
|
||||
)
|
||||
|
||||
prelim_predictions = sorted(
|
||||
prelim_predictions,
|
||||
key=lambda x: (x.start_log_prob + x.end_log_prob),
|
||||
reverse=True)
|
||||
prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
|
||||
)
|
||||
|
||||
seen_predictions = {}
|
||||
nbest = []
|
||||
@@ -779,10 +799,10 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
||||
# final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
||||
|
||||
# Previously used Bert untokenizer
|
||||
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
|
||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
|
||||
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
||||
|
||||
# Clean whitespace
|
||||
@@ -790,8 +810,7 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
||||
tok_text = " ".join(tok_text.split())
|
||||
orig_text = " ".join(orig_tokens)
|
||||
|
||||
final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
|
||||
verbose_logging)
|
||||
final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case, verbose_logging)
|
||||
|
||||
if final_text in seen_predictions:
|
||||
continue
|
||||
@@ -799,17 +818,13 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
||||
seen_predictions[final_text] = True
|
||||
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text=final_text,
|
||||
start_log_prob=pred.start_log_prob,
|
||||
end_log_prob=pred.end_log_prob))
|
||||
_NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
|
||||
)
|
||||
|
||||
# In very rare edge cases we could have no valid predictions. So we
|
||||
# just create a nonce prediction in this case to avoid failure.
|
||||
if not nbest:
|
||||
nbest.append(
|
||||
_NbestPrediction(text="", start_log_prob=-1e6,
|
||||
end_log_prob=-1e6))
|
||||
nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
|
||||
|
||||
total_scores = []
|
||||
best_non_null_entry = None
|
||||
@@ -850,7 +865,7 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
||||
with open(output_null_log_odds_file, "w") as writer:
|
||||
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||
|
||||
with open(orig_data_file, "r", encoding='utf-8') as reader:
|
||||
with open(orig_data_file, "r", encoding="utf-8") as reader:
|
||||
orig_data = json.load(reader)["data"]
|
||||
|
||||
qid_to_has_ans = make_qid_to_has_ans(orig_data)
|
||||
@@ -914,8 +929,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||
start_position = tok_text.find(pred_text)
|
||||
if start_position == -1:
|
||||
if verbose_logging:
|
||||
logger.info(
|
||||
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
|
||||
logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
|
||||
return orig_text
|
||||
end_position = start_position + len(pred_text) - 1
|
||||
|
||||
@@ -924,8 +938,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||
|
||||
if len(orig_ns_text) != len(tok_ns_text):
|
||||
if verbose_logging:
|
||||
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
|
||||
orig_ns_text, tok_ns_text)
|
||||
logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text)
|
||||
return orig_text
|
||||
|
||||
# We then project the characters in `pred_text` back to `orig_text` using
|
||||
@@ -956,7 +969,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||
logger.info("Couldn't map end position")
|
||||
return orig_text
|
||||
|
||||
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
|
||||
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
|
||||
return output_text
|
||||
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@ from .configuration_utils import PretrainedConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
XXX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-config.json",
|
||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-config.json",
|
||||
"xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-config.json",
|
||||
"xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-config.json",
|
||||
}
|
||||
|
||||
|
||||
@@ -63,24 +63,26 @@ class XxxConfig(PretrainedConfig):
|
||||
"""
|
||||
pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size=50257,
|
||||
n_positions=1024,
|
||||
n_ctx=1024,
|
||||
n_embd=768,
|
||||
n_layer=12,
|
||||
n_head=12,
|
||||
resid_pdrop=0.1,
|
||||
embd_pdrop=0.1,
|
||||
attn_pdrop=0.1,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
summary_type='cls_index',
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
summary_first_dropout=0.1,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50257,
|
||||
n_positions=1024,
|
||||
n_ctx=1024,
|
||||
n_embd=768,
|
||||
n_layer=12,
|
||||
n_head=12,
|
||||
resid_pdrop=0.1,
|
||||
embd_pdrop=0.1,
|
||||
attn_pdrop=0.1,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
summary_type="cls_index",
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
summary_first_dropout=0.1,
|
||||
**kwargs
|
||||
):
|
||||
super(XxxConfig, self).__init__(**kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.n_ctx = n_ctx
|
||||
|
||||
@@ -24,8 +24,10 @@ import torch
|
||||
from transformers import XxxConfig, XxxForPreTraining, load_tf_weights_in_xxx
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = XxxConfig.from_json_file(config_file)
|
||||
@@ -43,23 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
||||
args.config_file,
|
||||
args.pytorch_dump_path)
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
||||
|
||||
@@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
|
||||
# for the pretrained weights provided with the models
|
||||
####################################################
|
||||
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-tf_model.h5",
|
||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-tf_model.h5",
|
||||
"xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-tf_model.h5",
|
||||
"xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-tf_model.h5",
|
||||
}
|
||||
|
||||
####################################################
|
||||
@@ -69,9 +69,9 @@ TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
class TFXxxLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super(TFXxxLayer, self).__init__(**kwargs)
|
||||
self.attention = TFXxxAttention(config, name='attention')
|
||||
self.intermediate = TFXxxIntermediate(config, name='intermediate')
|
||||
self.transformer_output = TFXxxOutput(config, name='output')
|
||||
self.attention = TFXxxAttention(config, name="attention")
|
||||
self.intermediate = TFXxxIntermediate(config, name="intermediate")
|
||||
self.transformer_output = TFXxxOutput(config, name="output")
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, head_mask = inputs
|
||||
@@ -98,7 +98,9 @@ class TFXxxMainLayer(tf.keras.layers.Layer):
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models
|
||||
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
def call(
|
||||
self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False
|
||||
):
|
||||
# We allow three types of multi-inputs:
|
||||
# - traditional keyword arguments in the call method
|
||||
# - all the arguments provided as a dict in the first positional argument of call
|
||||
@@ -113,11 +115,11 @@ class TFXxxMainLayer(tf.keras.layers.Layer):
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
@@ -175,6 +177,7 @@ class TFXxxPreTrainedModel(TFPreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = XxxConfig
|
||||
pretrained_model_archive_map = TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "transformer"
|
||||
@@ -263,8 +266,12 @@ XXX_INPUTS_DOCSTRING = r"""
|
||||
than the model's internal embedding lookup matrix.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class TFXxxModel(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -297,17 +304,19 @@ class TFXxxModel(TFXxxPreTrainedModel):
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxModel, self).__init__(config, *inputs, **kwargs)
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
return outputs
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a `language modeling` head on top. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model with a `language modeling` head on top. """, XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING
|
||||
)
|
||||
class TFXxxForMaskedLM(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -333,26 +342,30 @@ class TFXxxForMaskedLM(TFXxxPreTrainedModel):
|
||||
prediction_scores = outputs[0]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxForMaskedLM, self).__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name='mlm')
|
||||
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||
self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name="mlm")
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))
|
||||
prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
|
||||
|
||||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
||||
|
||||
return outputs # prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -378,22 +391,23 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
|
||||
logits = outputs[0]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxForSequenceClassification, self).__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name='classifier')
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
)
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
|
||||
pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False))
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
@@ -401,9 +415,12 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
|
||||
return outputs # logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a token classification head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class TFXxxForTokenClassification(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -429,22 +446,23 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel):
|
||||
scores = outputs[0]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxForTokenClassification, self).__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name='classifier')
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
)
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False))
|
||||
sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
@@ -452,9 +470,12 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel):
|
||||
return outputs # scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -482,14 +503,15 @@ class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
|
||||
start_scores, end_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXxxForQuestionAnswering, self).__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
||||
self.qa_outputs = tf.keras.layers.Dense(config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name='qa_outputs')
|
||||
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
)
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
@@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
|
||||
# for the pretrained weights provided with the models
|
||||
####################################################
|
||||
XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-pytorch_model.bin",
|
||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-pytorch_model.bin",
|
||||
"xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-pytorch_model.bin",
|
||||
"xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-pytorch_model.bin",
|
||||
}
|
||||
|
||||
####################################################
|
||||
@@ -60,8 +60,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
logger.error(
|
||||
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
@@ -76,7 +78,7 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
||||
arrays.append(array)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name.split('/')
|
||||
name = name.split("/")
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
||||
@@ -84,18 +86,18 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
l = re.split(r'_(\d+)', m_name)
|
||||
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
||||
l = re.split(r"_(\d+)", m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'output_weights':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'squad':
|
||||
pointer = getattr(pointer, 'classifier')
|
||||
if l[0] == "kernel" or l[0] == "gamma":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif l[0] == "output_bias" or l[0] == "beta":
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif l[0] == "output_weights":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif l[0] == "squad":
|
||||
pointer = getattr(pointer, "classifier")
|
||||
else:
|
||||
try:
|
||||
pointer = getattr(pointer, l[0])
|
||||
@@ -105,9 +107,9 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
if m_name[-11:] == '_embeddings':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif m_name == 'kernel':
|
||||
if m_name[-11:] == "_embeddings":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif m_name == "kernel":
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
@@ -147,7 +149,6 @@ class XxxLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
####################################################
|
||||
# PreTrainedModel is a sub-class of torch.nn.Module
|
||||
# which take care of loading and saving pretrained weights
|
||||
@@ -161,6 +162,7 @@ class XxxPreTrainedModel(PreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = XxxConfig
|
||||
pretrained_model_archive_map = XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_xxx
|
||||
@@ -246,8 +248,12 @@ XXX_INPUTS_DOCSTRING = r"""
|
||||
than the model's internal embedding lookup matrix.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class XxxModel(XxxPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -277,6 +283,7 @@ class XxxModel(XxxPreTrainedModel):
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(XxxModel, self).__init__(config)
|
||||
|
||||
@@ -300,7 +307,15 @@ class XxxModel(XxxPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
):
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@@ -329,7 +344,7 @@ class XxxModel(XxxPreTrainedModel):
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
@@ -342,14 +357,20 @@ class XxxModel(XxxPreTrainedModel):
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
||||
head_mask = (
|
||||
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
||||
) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
) # switch to fload if need + fp16 compatibility
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
##################################
|
||||
# Replace this with your model code
|
||||
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)
|
||||
sequence_output = encoder_outputs[0]
|
||||
outputs = (sequence_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||
@@ -357,8 +378,9 @@ class XxxModel(XxxPreTrainedModel):
|
||||
return outputs # sequence_output, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a `language modeling` head on top. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model with a `language modeling` head on top. """, XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING
|
||||
)
|
||||
class XxxForMaskedLM(XxxPreTrainedModel):
|
||||
r"""
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -389,6 +411,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
|
||||
loss, prediction_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(XxxForMaskedLM, self).__init__(config)
|
||||
|
||||
@@ -400,15 +423,25 @@ class XxxForMaskedLM(XxxPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
masked_lm_labels=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
masked_lm_labels=None,
|
||||
):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.cls(sequence_output)
|
||||
@@ -422,9 +455,12 @@ class XxxForMaskedLM(XxxPreTrainedModel):
|
||||
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class XxxForSequenceClassification(XxxPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -456,6 +492,7 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(XxxForSequenceClassification, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@@ -466,15 +503,25 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
@@ -496,9 +543,12 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
|
||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a token classification head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class XxxForTokenClassification(XxxPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -528,6 +578,7 @@ class XxxForTokenClassification(XxxPreTrainedModel):
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(XxxForTokenClassification, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@@ -538,15 +589,25 @@ class XxxForTokenClassification(XxxPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
@@ -569,9 +630,12 @@ class XxxForTokenClassification(XxxPreTrainedModel):
|
||||
return outputs # (loss), scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
@add_start_docstrings(
|
||||
"""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
||||
XXX_START_DOCSTRING,
|
||||
XXX_INPUTS_DOCSTRING,
|
||||
)
|
||||
class XxxForQuestionAnswering(XxxPreTrainedModel):
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -613,6 +677,7 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(XxxForQuestionAnswering, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@@ -622,15 +687,26 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
start_positions=None, end_positions=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from __future__ import print_function
|
||||
import unittest
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .modeling_tf_common_test import TFCommonTestCases, ids_tensor
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
@@ -27,46 +27,57 @@ from transformers import XxxConfig, is_tf_available
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
from transformers.modeling_tf_xxx import (TFXxxModel, TFXxxForMaskedLM,
|
||||
TFXxxForSequenceClassification,
|
||||
TFXxxForTokenClassification,
|
||||
TFXxxForQuestionAnswering,
|
||||
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from transformers.modeling_tf_xxx import (
|
||||
TFXxxModel,
|
||||
TFXxxForMaskedLM,
|
||||
TFXxxForSequenceClassification,
|
||||
TFXxxForTokenClassification,
|
||||
TFXxxForQuestionAnswering,
|
||||
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
all_model_classes = (TFXxxModel, TFXxxForMaskedLM, TFXxxForQuestionAnswering,
|
||||
TFXxxForSequenceClassification,
|
||||
TFXxxForTokenClassification) if is_tf_available() else ()
|
||||
all_model_classes = (
|
||||
(
|
||||
TFXxxModel,
|
||||
TFXxxForMaskedLM,
|
||||
TFXxxForQuestionAnswering,
|
||||
TFXxxForSequenceClassification,
|
||||
TFXxxForTokenClassification,
|
||||
)
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
class TFXxxModelTester(object):
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
@@ -120,15 +131,16 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range)
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def create_and_check_xxx_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFXxxModel(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
sequence_output, pooled_output = model(inputs)
|
||||
|
||||
inputs = [input_ids, input_mask]
|
||||
@@ -141,78 +153,74 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
"pooled_output": pooled_output.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].shape),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])
|
||||
|
||||
|
||||
def create_and_check_xxx_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFXxxForMaskedLM(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
prediction_scores, = model(inputs)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
(prediction_scores,) = model(inputs)
|
||||
result = {
|
||||
"prediction_scores": prediction_scores.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].shape),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFXxxForSequenceClassification(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
logits, = model(inputs)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
(logits,) = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape),
|
||||
[self.batch_size, self.num_labels])
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||
|
||||
|
||||
def create_and_check_xxx_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFXxxForTokenClassification(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
logits, = model(inputs)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
(logits,) = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape),
|
||||
[self.batch_size, self.seq_length, self.num_labels])
|
||||
list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
|
||||
)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFXxxForQuestionAnswering(config=config)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
start_logits, end_logits = model(inputs)
|
||||
result = {
|
||||
"start_logits": start_logits.numpy(),
|
||||
"end_logits": end_logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].shape),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].shape),
|
||||
[self.batch_size, self.seq_length])
|
||||
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, token_type_ids, input_mask,
|
||||
sequence_labels, token_labels, choice_labels) = config_and_inputs
|
||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
@@ -244,9 +252,10 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in ['xxx-base-uncased']:
|
||||
for model_name in ["xxx-base-uncased"]:
|
||||
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -20,51 +20,60 @@ import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .modeling_common_test import CommonTestCases, ids_tensor
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (XxxConfig, XxxModel, XxxForMaskedLM,
|
||||
XxxForNextSentencePrediction, XxxForPreTraining,
|
||||
XxxForQuestionAnswering, XxxForSequenceClassification,
|
||||
XxxForTokenClassification, XxxForMultipleChoice)
|
||||
from transformers import (
|
||||
XxxConfig,
|
||||
XxxModel,
|
||||
XxxForMaskedLM,
|
||||
XxxForNextSentencePrediction,
|
||||
XxxForPreTraining,
|
||||
XxxForQuestionAnswering,
|
||||
XxxForSequenceClassification,
|
||||
XxxForTokenClassification,
|
||||
XxxForMultipleChoice,
|
||||
)
|
||||
from transformers.modeling_xxx import XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
|
||||
@require_torch
|
||||
class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (XxxModel, XxxForMaskedLM, XxxForQuestionAnswering,
|
||||
XxxForSequenceClassification,
|
||||
XxxForTokenClassification) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(XxxModel, XxxForMaskedLM, XxxForQuestionAnswering, XxxForSequenceClassification, XxxForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
class XxxModelTester(object):
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
@@ -118,16 +127,17 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range)
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
|
||||
def create_and_check_xxx_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = XxxModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -140,83 +150,98 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
|
||||
def create_and_check_xxx_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = XxxForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = XxxForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
loss, start_logits, end_logits = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = XxxForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.num_labels])
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_xxx_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_xxx_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = XxxForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.seq_length, self.num_labels])
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, token_type_ids, input_mask,
|
||||
sequence_labels, token_labels, choice_labels) = config_and_inputs
|
||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
@@ -252,5 +277,6 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -18,10 +18,11 @@ import os
|
||||
import unittest
|
||||
from io import open
|
||||
|
||||
from transformers.tokenization_bert import (XxxTokenizer, VOCAB_FILES_NAMES)
|
||||
from transformers.tokenization_bert import XxxTokenizer, VOCAB_FILES_NAMES
|
||||
|
||||
from .tokenization_tests_commons import CommonTestCases
|
||||
|
||||
|
||||
class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
|
||||
tokenizer_class = XxxTokenizer
|
||||
@@ -30,28 +31,39 @@ class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
super(XxxTokenizationTest, self).setUp()
|
||||
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing", ",", "low", "lowest",
|
||||
"[UNK]",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"want",
|
||||
"##want",
|
||||
"##ed",
|
||||
"wa",
|
||||
"un",
|
||||
"runn",
|
||||
"##ing",
|
||||
",",
|
||||
"low",
|
||||
"lowest",
|
||||
]
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return XxxTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = u"UNwant\u00E9d,running"
|
||||
output_text = u"unwanted, running"
|
||||
input_text = "UNwant\u00E9d,running"
|
||||
output_text = "unwanted, running"
|
||||
return input_text, output_text
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.tokenizer_class(self.vocab_file)
|
||||
|
||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||
tokens = tokenizer.tokenize("UNwant\u00E9d,running")
|
||||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -34,17 +34,16 @@ logger = logging.getLogger(__name__)
|
||||
# Mapping from the keyword arguments names of Tokenizer `__init__`
|
||||
# to file names for serializing Tokenizer instances
|
||||
####################################################
|
||||
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||
|
||||
####################################################
|
||||
# Mapping from the keyword arguments names of Tokenizer `__init__`
|
||||
# to pretrained vocabulary URL for all the model shortcut names.
|
||||
####################################################
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
'vocab_file':
|
||||
{
|
||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-vocab.txt",
|
||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-vocab.txt",
|
||||
"vocab_file": {
|
||||
"xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-vocab.txt",
|
||||
"xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-vocab.txt",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,8 +51,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
||||
# Mapping from model shortcut names to max length of inputs
|
||||
####################################################
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
'xxx-base-uncased': 512,
|
||||
'xxx-large-uncased': 512,
|
||||
"xxx-base-uncased": 512,
|
||||
"xxx-large-uncased": 512,
|
||||
}
|
||||
|
||||
####################################################
|
||||
@@ -62,8 +61,8 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
# To be used for checkpoint specific configurations.
|
||||
####################################################
|
||||
PRETRAINED_INIT_CONFIGURATION = {
|
||||
'xxx-base-uncased': {'do_lower_case': True},
|
||||
'xxx-large-uncased': {'do_lower_case': True},
|
||||
"xxx-base-uncased": {"do_lower_case": True},
|
||||
"xxx-large-uncased": {"do_lower_case": True},
|
||||
}
|
||||
|
||||
|
||||
@@ -73,7 +72,7 @@ def load_vocab(vocab_file):
|
||||
with open(vocab_file, "r", encoding="utf-8") as reader:
|
||||
tokens = reader.readlines()
|
||||
for index, token in enumerate(tokens):
|
||||
token = token.rstrip('\n')
|
||||
token = token.rstrip("\n")
|
||||
vocab[token] = index
|
||||
return vocab
|
||||
|
||||
@@ -93,9 +92,17 @@ class XxxTokenizer(PreTrainedTokenizer):
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True,
|
||||
unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
|
||||
mask_token="[MASK]", **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
do_lower_case=True,
|
||||
unk_token="[UNK]",
|
||||
sep_token="[SEP]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
mask_token="[MASK]",
|
||||
**kwargs
|
||||
):
|
||||
"""Constructs a XxxTokenizer.
|
||||
|
||||
Args:
|
||||
@@ -104,16 +111,22 @@ class XxxTokenizer(PreTrainedTokenizer):
|
||||
Whether to lower case the input
|
||||
Only has an effect when do_basic_tokenize=True
|
||||
"""
|
||||
super(XxxTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
|
||||
pad_token=pad_token, cls_token=cls_token,
|
||||
mask_token=mask_token, **kwargs)
|
||||
super(XxxTokenizer, self).__init__(
|
||||
unk_token=unk_token,
|
||||
sep_token=sep_token,
|
||||
pad_token=pad_token,
|
||||
cls_token=cls_token,
|
||||
mask_token=mask_token,
|
||||
**kwargs
|
||||
)
|
||||
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
|
||||
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
|
||||
|
||||
if not os.path.isfile(vocab_file):
|
||||
raise ValueError(
|
||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
||||
"model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
|
||||
"model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
|
||||
)
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
|
||||
@property
|
||||
@@ -142,7 +155,7 @@ class XxxTokenizer(PreTrainedTokenizer):
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
out_string = ' '.join(tokens).replace(' ##', '').strip()
|
||||
out_string = " ".join(tokens).replace(" ##", "").strip()
|
||||
return out_string
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
@@ -177,8 +190,10 @@ class XxxTokenizer(PreTrainedTokenizer):
|
||||
|
||||
if already_has_special_tokens:
|
||||
if token_ids_1 is not None:
|
||||
raise ValueError("You should not supply a second sequence if the provided sequence of "
|
||||
"ids is already formated with special tokens for the model.")
|
||||
raise ValueError(
|
||||
"You should not supply a second sequence if the provided sequence of "
|
||||
"ids is already formated with special tokens for the model."
|
||||
)
|
||||
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
||||
|
||||
if token_ids_1 is not None:
|
||||
@@ -204,15 +219,17 @@ class XxxTokenizer(PreTrainedTokenizer):
|
||||
"""Save the tokenizer vocabulary to a directory or file."""
|
||||
index = 0
|
||||
if os.path.isdir(vocab_path):
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
|
||||
else:
|
||||
vocab_file = vocab_path
|
||||
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
|
||||
" Please check that the vocabulary is not corrupted!".format(vocab_file))
|
||||
logger.warning(
|
||||
"Saving vocabulary to {}: vocabulary indices are not consecutive."
|
||||
" Please check that the vocabulary is not corrupted!".format(vocab_file)
|
||||
)
|
||||
index = token_index
|
||||
writer.write(token + u'\n')
|
||||
writer.write(token + "\n")
|
||||
index += 1
|
||||
return (vocab_file,)
|
||||
|
||||
@@ -6,8 +6,9 @@ __version__ = "2.3.0"
|
||||
# and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493
|
||||
try:
|
||||
import absl.logging
|
||||
absl.logging.set_verbosity('info')
|
||||
absl.logging.set_stderrthreshold('info')
|
||||
|
||||
absl.logging.set_verbosity("info")
|
||||
absl.logging.set_stderrthreshold("info")
|
||||
absl.logging._warn_preinit_stderr = False
|
||||
except:
|
||||
pass
|
||||
@@ -17,19 +18,41 @@ import logging
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# Files and general utilities
|
||||
from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||
cached_path, add_start_docstrings, add_end_docstrings,
|
||||
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME, MODEL_CARD_NAME,
|
||||
is_tf_available, is_torch_available)
|
||||
from .file_utils import (
|
||||
TRANSFORMERS_CACHE,
|
||||
PYTORCH_TRANSFORMERS_CACHE,
|
||||
PYTORCH_PRETRAINED_BERT_CACHE,
|
||||
cached_path,
|
||||
add_start_docstrings,
|
||||
add_end_docstrings,
|
||||
WEIGHTS_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
TF_WEIGHTS_NAME,
|
||||
CONFIG_NAME,
|
||||
MODEL_CARD_NAME,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
from .data import (is_sklearn_available,
|
||||
InputExample, InputFeatures, DataProcessor,
|
||||
SingleSentenceClassificationProcessor,
|
||||
glue_output_modes, glue_convert_examples_to_features,
|
||||
glue_processors, glue_tasks_num_labels,
|
||||
xnli_output_modes, xnli_processors, xnli_tasks_num_labels,
|
||||
squad_convert_examples_to_features, SquadFeatures,
|
||||
SquadExample, SquadV1Processor, SquadV2Processor)
|
||||
from .data import (
|
||||
is_sklearn_available,
|
||||
InputExample,
|
||||
InputFeatures,
|
||||
DataProcessor,
|
||||
SingleSentenceClassificationProcessor,
|
||||
glue_output_modes,
|
||||
glue_convert_examples_to_features,
|
||||
glue_processors,
|
||||
glue_tasks_num_labels,
|
||||
xnli_output_modes,
|
||||
xnli_processors,
|
||||
xnli_tasks_num_labels,
|
||||
squad_convert_examples_to_features,
|
||||
SquadFeatures,
|
||||
SquadExample,
|
||||
SquadV1Processor,
|
||||
SquadV2Processor,
|
||||
)
|
||||
|
||||
if is_sklearn_available():
|
||||
from .data import glue_compute_metrics, xnli_compute_metrics
|
||||
@@ -38,12 +61,12 @@ if is_sklearn_available():
|
||||
from .modelcard import ModelCard
|
||||
|
||||
# Tokenizers
|
||||
from .tokenization_utils import (PreTrainedTokenizer)
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_auto import AutoTokenizer
|
||||
from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
||||
from .tokenization_bert_japanese import BertJapaneseTokenizer, MecabTokenizer, CharacterTokenizer
|
||||
from .tokenization_openai import OpenAIGPTTokenizer
|
||||
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
|
||||
from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLCorpus
|
||||
from .tokenization_gpt2 import GPT2Tokenizer
|
||||
from .tokenization_ctrl import CTRLTokenizer
|
||||
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
|
||||
@@ -75,143 +98,281 @@ from .configuration_mmbt import MMBTConfig
|
||||
|
||||
# Modeling
|
||||
if is_torch_available():
|
||||
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
|
||||
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
|
||||
AutoModelWithLMHead, AutoModelForTokenClassification, ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D
|
||||
from .modeling_auto import (
|
||||
AutoModel,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelWithLMHead,
|
||||
AutoModelForTokenClassification,
|
||||
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_bert import (BertPreTrainedModel, BertModel, BertForPreTraining,
|
||||
BertForMaskedLM, BertForNextSentencePrediction,
|
||||
BertForSequenceClassification, BertForMultipleChoice,
|
||||
BertForTokenClassification, BertForQuestionAnswering,
|
||||
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
|
||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
|
||||
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel,
|
||||
AdaptiveEmbedding,
|
||||
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
|
||||
GPT2LMHeadModel, GPT2DoubleHeadsModel,
|
||||
load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_ctrl import (CTRLPreTrainedModel, CTRLModel,
|
||||
CTRLLMHeadModel,
|
||||
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
|
||||
XLNetForSequenceClassification, XLNetForTokenClassification,
|
||||
XLNetForMultipleChoice, XLNetForQuestionAnsweringSimple,
|
||||
XLNetForQuestionAnswering, load_tf_weights_in_xlnet,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
|
||||
XLMWithLMHeadModel, XLMForSequenceClassification,
|
||||
XLMForQuestionAnswering, XLMForQuestionAnsweringSimple,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel,
|
||||
RobertaForSequenceClassification, RobertaForMultipleChoice,
|
||||
RobertaForTokenClassification, RobertaForQuestionAnswering,
|
||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_distilbert import (DistilBertPreTrainedModel, DistilBertForMaskedLM, DistilBertModel,
|
||||
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
|
||||
DistilBertForTokenClassification,
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_camembert import (CamembertForMaskedLM, CamembertModel,
|
||||
CamembertForSequenceClassification, CamembertForMultipleChoice,
|
||||
CamembertForTokenClassification,
|
||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_bert import (
|
||||
BertPreTrainedModel,
|
||||
BertModel,
|
||||
BertForPreTraining,
|
||||
BertForMaskedLM,
|
||||
BertForNextSentencePrediction,
|
||||
BertForSequenceClassification,
|
||||
BertForMultipleChoice,
|
||||
BertForTokenClassification,
|
||||
BertForQuestionAnswering,
|
||||
load_tf_weights_in_bert,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_openai import (
|
||||
OpenAIGPTPreTrainedModel,
|
||||
OpenAIGPTModel,
|
||||
OpenAIGPTLMHeadModel,
|
||||
OpenAIGPTDoubleHeadsModel,
|
||||
load_tf_weights_in_openai_gpt,
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_transfo_xl import (
|
||||
TransfoXLPreTrainedModel,
|
||||
TransfoXLModel,
|
||||
TransfoXLLMHeadModel,
|
||||
AdaptiveEmbedding,
|
||||
load_tf_weights_in_transfo_xl,
|
||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_gpt2 import (
|
||||
GPT2PreTrainedModel,
|
||||
GPT2Model,
|
||||
GPT2LMHeadModel,
|
||||
GPT2DoubleHeadsModel,
|
||||
load_tf_weights_in_gpt2,
|
||||
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_ctrl import CTRLPreTrainedModel, CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_xlnet import (
|
||||
XLNetPreTrainedModel,
|
||||
XLNetModel,
|
||||
XLNetLMHeadModel,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetForTokenClassification,
|
||||
XLNetForMultipleChoice,
|
||||
XLNetForQuestionAnsweringSimple,
|
||||
XLNetForQuestionAnswering,
|
||||
load_tf_weights_in_xlnet,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_xlm import (
|
||||
XLMPreTrainedModel,
|
||||
XLMModel,
|
||||
XLMWithLMHeadModel,
|
||||
XLMForSequenceClassification,
|
||||
XLMForQuestionAnswering,
|
||||
XLMForQuestionAnsweringSimple,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_roberta import (
|
||||
RobertaForMaskedLM,
|
||||
RobertaModel,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaForMultipleChoice,
|
||||
RobertaForTokenClassification,
|
||||
RobertaForQuestionAnswering,
|
||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_distilbert import (
|
||||
DistilBertPreTrainedModel,
|
||||
DistilBertForMaskedLM,
|
||||
DistilBertModel,
|
||||
DistilBertForSequenceClassification,
|
||||
DistilBertForQuestionAnswering,
|
||||
DistilBertForTokenClassification,
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_camembert import (
|
||||
CamembertForMaskedLM,
|
||||
CamembertModel,
|
||||
CamembertForSequenceClassification,
|
||||
CamembertForMultipleChoice,
|
||||
CamembertForTokenClassification,
|
||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
|
||||
from .modeling_t5 import (T5PreTrainedModel, T5Model, T5WithLMHeadModel,
|
||||
load_tf_weights_in_t5,
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_albert import (AlbertPreTrainedModel, AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification,
|
||||
AlbertForQuestionAnswering,
|
||||
load_tf_weights_in_albert, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_xlm_roberta import (XLMRobertaForMaskedLM, XLMRobertaModel, XLMRobertaForMultipleChoice,
|
||||
XLMRobertaForSequenceClassification, XLMRobertaForTokenClassification)
|
||||
from .modeling_t5 import (
|
||||
T5PreTrainedModel,
|
||||
T5Model,
|
||||
T5WithLMHeadModel,
|
||||
load_tf_weights_in_t5,
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_albert import (
|
||||
AlbertPreTrainedModel,
|
||||
AlbertModel,
|
||||
AlbertForMaskedLM,
|
||||
AlbertForSequenceClassification,
|
||||
AlbertForQuestionAnswering,
|
||||
load_tf_weights_in_albert,
|
||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_xlm_roberta import (
|
||||
XLMRobertaForMaskedLM,
|
||||
XLMRobertaModel,
|
||||
XLMRobertaForMultipleChoice,
|
||||
XLMRobertaForSequenceClassification,
|
||||
XLMRobertaForTokenClassification,
|
||||
)
|
||||
from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification
|
||||
|
||||
# Optimization
|
||||
from .optimization import (AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup,
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup)
|
||||
from .optimization import (
|
||||
AdamW,
|
||||
get_constant_schedule,
|
||||
get_constant_schedule_with_warmup,
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
|
||||
|
||||
# TensorFlow
|
||||
if is_tf_available():
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
|
||||
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
|
||||
TFAutoModelWithLMHead, TFAutoModelForTokenClassification, TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_tf_auto import (
|
||||
TFAutoModel,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFAutoModelForQuestionAnswering,
|
||||
TFAutoModelWithLMHead,
|
||||
TFAutoModelForTokenClassification,
|
||||
TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertMainLayer, TFBertEmbeddings,
|
||||
TFBertModel, TFBertForPreTraining,
|
||||
TFBertForMaskedLM, TFBertForNextSentencePrediction,
|
||||
TFBertForSequenceClassification, TFBertForMultipleChoice,
|
||||
TFBertForTokenClassification, TFBertForQuestionAnswering,
|
||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_tf_bert import (
|
||||
TFBertPreTrainedModel,
|
||||
TFBertMainLayer,
|
||||
TFBertEmbeddings,
|
||||
TFBertModel,
|
||||
TFBertForPreTraining,
|
||||
TFBertForMaskedLM,
|
||||
TFBertForNextSentencePrediction,
|
||||
TFBertForSequenceClassification,
|
||||
TFBertForMultipleChoice,
|
||||
TFBertForTokenClassification,
|
||||
TFBertForQuestionAnswering,
|
||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer,
|
||||
TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
|
||||
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_tf_gpt2 import (
|
||||
TFGPT2PreTrainedModel,
|
||||
TFGPT2MainLayer,
|
||||
TFGPT2Model,
|
||||
TFGPT2LMHeadModel,
|
||||
TFGPT2DoubleHeadsModel,
|
||||
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_tf_openai import (TFOpenAIGPTPreTrainedModel, TFOpenAIGPTMainLayer,
|
||||
TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel,
|
||||
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_tf_openai import (
|
||||
TFOpenAIGPTPreTrainedModel,
|
||||
TFOpenAIGPTMainLayer,
|
||||
TFOpenAIGPTModel,
|
||||
TFOpenAIGPTLMHeadModel,
|
||||
TFOpenAIGPTDoubleHeadsModel,
|
||||
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_tf_transfo_xl import (TFTransfoXLPreTrainedModel, TFTransfoXLMainLayer,
|
||||
TFTransfoXLModel, TFTransfoXLLMHeadModel,
|
||||
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_tf_transfo_xl import (
|
||||
TFTransfoXLPreTrainedModel,
|
||||
TFTransfoXLMainLayer,
|
||||
TFTransfoXLModel,
|
||||
TFTransfoXLLMHeadModel,
|
||||
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
|
||||
TFXLNetModel, TFXLNetLMHeadModel,
|
||||
TFXLNetForSequenceClassification,
|
||||
TFXLNetForTokenClassification,
|
||||
TFXLNetForQuestionAnsweringSimple,
|
||||
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_tf_xlnet import (
|
||||
TFXLNetPreTrainedModel,
|
||||
TFXLNetMainLayer,
|
||||
TFXLNetModel,
|
||||
TFXLNetLMHeadModel,
|
||||
TFXLNetForSequenceClassification,
|
||||
TFXLNetForTokenClassification,
|
||||
TFXLNetForQuestionAnsweringSimple,
|
||||
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_tf_xlm import (TFXLMPreTrainedModel, TFXLMMainLayer,
|
||||
TFXLMModel, TFXLMWithLMHeadModel,
|
||||
TFXLMForSequenceClassification,
|
||||
TFXLMForQuestionAnsweringSimple,
|
||||
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_tf_xlm import (
|
||||
TFXLMPreTrainedModel,
|
||||
TFXLMMainLayer,
|
||||
TFXLMModel,
|
||||
TFXLMWithLMHeadModel,
|
||||
TFXLMForSequenceClassification,
|
||||
TFXLMForQuestionAnsweringSimple,
|
||||
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer,
|
||||
TFRobertaModel, TFRobertaForMaskedLM,
|
||||
TFRobertaForSequenceClassification,
|
||||
TFRobertaForTokenClassification,
|
||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_tf_roberta import (
|
||||
TFRobertaPreTrainedModel,
|
||||
TFRobertaMainLayer,
|
||||
TFRobertaModel,
|
||||
TFRobertaForMaskedLM,
|
||||
TFRobertaForSequenceClassification,
|
||||
TFRobertaForTokenClassification,
|
||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
|
||||
TFDistilBertModel, TFDistilBertForMaskedLM,
|
||||
TFDistilBertForSequenceClassification,
|
||||
TFDistilBertForTokenClassification,
|
||||
TFDistilBertForQuestionAnswering,
|
||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_tf_distilbert import (
|
||||
TFDistilBertPreTrainedModel,
|
||||
TFDistilBertMainLayer,
|
||||
TFDistilBertModel,
|
||||
TFDistilBertForMaskedLM,
|
||||
TFDistilBertForSequenceClassification,
|
||||
TFDistilBertForTokenClassification,
|
||||
TFDistilBertForQuestionAnswering,
|
||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel,
|
||||
TFCTRLLMHeadModel,
|
||||
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_tf_ctrl import (
|
||||
TFCTRLPreTrainedModel,
|
||||
TFCTRLModel,
|
||||
TFCTRLLMHeadModel,
|
||||
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_tf_albert import (TFAlbertPreTrainedModel, TFAlbertModel, TFAlbertForMaskedLM,
|
||||
TFAlbertForSequenceClassification,
|
||||
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_tf_albert import (
|
||||
TFAlbertPreTrainedModel,
|
||||
TFAlbertModel,
|
||||
TFAlbertForMaskedLM,
|
||||
TFAlbertForSequenceClassification,
|
||||
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_tf_t5 import (TFT5PreTrainedModel, TFT5Model, TFT5WithLMHeadModel,
|
||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_tf_t5 import TFT5PreTrainedModel, TFT5Model, TFT5WithLMHeadModel, TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
# Optimization
|
||||
from .optimization_tf import (WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator)
|
||||
from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator
|
||||
|
||||
# TF 2.0 <=> PyTorch conversion utilities
|
||||
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
|
||||
load_pytorch_checkpoint_in_tf2_model,
|
||||
load_pytorch_weights_in_tf2_model,
|
||||
load_pytorch_model_in_tf2_model,
|
||||
load_tf2_checkpoint_in_pytorch_model,
|
||||
load_tf2_weights_in_pytorch_model,
|
||||
load_tf2_model_in_pytorch_model)
|
||||
from .modeling_tf_pytorch_utils import (
|
||||
convert_tf_weight_name_to_pt_weight_name,
|
||||
load_pytorch_checkpoint_in_tf2_model,
|
||||
load_pytorch_weights_in_tf2_model,
|
||||
load_pytorch_model_in_tf2_model,
|
||||
load_tf2_checkpoint_in_pytorch_model,
|
||||
load_tf2_weights_in_pytorch_model,
|
||||
load_tf2_model_in_pytorch_model,
|
||||
)
|
||||
|
||||
# Pipelines
|
||||
from .pipelines import pipeline, PipelineDataFormat, CsvPipelineDataFormat, JsonPipelineDataFormat, PipedPipelineDataFormat, \
|
||||
Pipeline, FeatureExtractionPipeline, QuestionAnsweringPipeline, NerPipeline, TextClassificationPipeline
|
||||
from .pipelines import (
|
||||
pipeline,
|
||||
PipelineDataFormat,
|
||||
CsvPipelineDataFormat,
|
||||
JsonPipelineDataFormat,
|
||||
PipedPipelineDataFormat,
|
||||
Pipeline,
|
||||
FeatureExtractionPipeline,
|
||||
QuestionAnsweringPipeline,
|
||||
NerPipeline,
|
||||
TextClassificationPipeline,
|
||||
)
|
||||
|
||||
if not is_tf_available() and not is_torch_available():
|
||||
logger.warning("Neither PyTorch nor TensorFlow >= 2.0 have been found."
|
||||
"Models won't be available and only tokenizers, configuration"
|
||||
"and file/data utilities can be used.")
|
||||
logger.warning(
|
||||
"Neither PyTorch nor TensorFlow >= 2.0 have been found."
|
||||
"Models won't be available and only tokenizers, configuration"
|
||||
"and file/data utilities can be used."
|
||||
)
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
# coding: utf8
|
||||
|
||||
|
||||
def main():
|
||||
import sys
|
||||
|
||||
if len(sys.argv) < 2 or sys.argv[1] not in ["convert", "train", "predict", "serve"]:
|
||||
print(
|
||||
"First argument to `transformers` command line interface should be one of: \n"
|
||||
">> convert serve train predict")
|
||||
"First argument to `transformers` command line interface should be one of: \n"
|
||||
">> convert serve train predict"
|
||||
)
|
||||
if sys.argv[1] == "convert":
|
||||
from transformers.commands import convert
|
||||
|
||||
convert(sys.argv)
|
||||
elif sys.argv[1] == "train":
|
||||
from transformers.commands import train
|
||||
|
||||
train(sys.argv)
|
||||
elif sys.argv[1] == "serve":
|
||||
pass
|
||||
@@ -19,7 +24,6 @@ def main():
|
||||
# parser = ArgumentParser('Transformers CLI tool', usage='transformers serve <command> [<args>]')
|
||||
# commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
|
||||
|
||||
|
||||
# # Register commands
|
||||
# ServeCommand.register_subcommand(commands_parser)
|
||||
|
||||
@@ -33,5 +37,6 @@ def main():
|
||||
# service = args.func(args)
|
||||
# service.run()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from argparse import ArgumentParser
|
||||
|
||||
|
||||
class BaseTransformersCLICommand(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
|
||||
@@ -11,12 +11,12 @@ def convert_command_factory(args: Namespace):
|
||||
Factory function used to convert a model TF 1.0 checkpoint in a PyTorch checkpoint.
|
||||
:return: ServeCommand
|
||||
"""
|
||||
return ConvertCommand(args.model_type, args.tf_checkpoint, args.pytorch_dump_output,
|
||||
args.config, args.finetuning_task_name)
|
||||
return ConvertCommand(
|
||||
args.model_type, args.tf_checkpoint, args.pytorch_dump_output, args.config, args.finetuning_task_name
|
||||
)
|
||||
|
||||
|
||||
class ConvertCommand(BaseTransformersCLICommand):
|
||||
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
"""
|
||||
@@ -24,25 +24,39 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
:param parser: Root parser to register command-specific arguments
|
||||
:return:
|
||||
"""
|
||||
train_parser = parser.add_parser('convert', help="CLI tool to run convert model from original "
|
||||
"author checkpoints to Transformesr PyTorch checkpoints.")
|
||||
train_parser.add_argument('--model_type', type=str, required=True,
|
||||
help='Model\'s type.')
|
||||
train_parser.add_argument('--tf_checkpoint', type=str, required=True,
|
||||
help='TensorFlow checkpoint path or folder.')
|
||||
train_parser.add_argument('--pytorch_dump_output', type=str, required=True,
|
||||
help='Path to the PyTorch savd model output.')
|
||||
train_parser.add_argument('--config', type=str, default="",
|
||||
help='Configuration file path or folder.')
|
||||
train_parser.add_argument('--finetuning_task_name', type=str, default=None,
|
||||
help='Optional fine-tuning task name if the TF model was a finetuned model.')
|
||||
train_parser = parser.add_parser(
|
||||
"convert",
|
||||
help="CLI tool to run convert model from original "
|
||||
"author checkpoints to Transformesr PyTorch checkpoints.",
|
||||
)
|
||||
train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
|
||||
train_parser.add_argument(
|
||||
"--tf_checkpoint", type=str, required=True, help="TensorFlow checkpoint path or folder."
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--pytorch_dump_output", type=str, required=True, help="Path to the PyTorch savd model output."
|
||||
)
|
||||
train_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.")
|
||||
train_parser.add_argument(
|
||||
"--finetuning_task_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional fine-tuning task name if the TF model was a finetuned model.",
|
||||
)
|
||||
train_parser.set_defaults(func=convert_command_factory)
|
||||
|
||||
def __init__(self, model_type: str, tf_checkpoint: str, pytorch_dump_output: str,
|
||||
config: str, finetuning_task_name: str, *args):
|
||||
self._logger = getLogger('transformers-cli/converting')
|
||||
def __init__(
|
||||
self,
|
||||
model_type: str,
|
||||
tf_checkpoint: str,
|
||||
pytorch_dump_output: str,
|
||||
config: str,
|
||||
finetuning_task_name: str,
|
||||
*args
|
||||
):
|
||||
self._logger = getLogger("transformers-cli/converting")
|
||||
|
||||
self._logger.info('Loading model {}'.format(model_type))
|
||||
self._logger.info("Loading model {}".format(model_type))
|
||||
self._model_type = model_type
|
||||
self._tf_checkpoint = tf_checkpoint
|
||||
self._pytorch_dump_output = pytorch_dump_output
|
||||
@@ -52,63 +66,80 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
def run(self):
|
||||
if self._model_type == "bert":
|
||||
try:
|
||||
from transformers.convert_bert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
|
||||
from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \
|
||||
"In that case, it requires TensorFlow to be installed. Please see " \
|
||||
msg = (
|
||||
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "gpt":
|
||||
from transformers.convert_openai_original_tf_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
|
||||
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint,
|
||||
self._config,
|
||||
self._pytorch_dump_output)
|
||||
from transformers.convert_openai_original_tf_checkpoint_to_pytorch import (
|
||||
convert_openai_checkpoint_to_pytorch,
|
||||
)
|
||||
|
||||
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "transfo_xl":
|
||||
try:
|
||||
from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
|
||||
from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
|
||||
convert_transfo_xl_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \
|
||||
"In that case, it requires TensorFlow to be installed. Please see " \
|
||||
msg = (
|
||||
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
if 'ckpt' in self._tf_checkpoint.lower():
|
||||
if "ckpt" in self._tf_checkpoint.lower():
|
||||
TF_CHECKPOINT = self._tf_checkpoint
|
||||
TF_DATASET_FILE = ""
|
||||
else:
|
||||
TF_DATASET_FILE = self._tf_checkpoint
|
||||
TF_CHECKPOINT = ""
|
||||
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT,
|
||||
self._config,
|
||||
self._pytorch_dump_output,
|
||||
TF_DATASET_FILE)
|
||||
convert_transfo_xl_checkpoint_to_pytorch(
|
||||
TF_CHECKPOINT, self._config, self._pytorch_dump_output, TF_DATASET_FILE
|
||||
)
|
||||
elif self._model_type == "gpt2":
|
||||
try:
|
||||
from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
|
||||
from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import (
|
||||
convert_gpt2_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \
|
||||
"In that case, it requires TensorFlow to be installed. Please see " \
|
||||
msg = (
|
||||
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "xlnet":
|
||||
try:
|
||||
from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
|
||||
from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import (
|
||||
convert_xlnet_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \
|
||||
"In that case, it requires TensorFlow to be installed. Please see " \
|
||||
msg = (
|
||||
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
convert_xlnet_checkpoint_to_pytorch(self._tf_checkpoint,
|
||||
self._config,
|
||||
self._pytorch_dump_output,
|
||||
self._finetuning_task_name)
|
||||
convert_xlnet_checkpoint_to_pytorch(
|
||||
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
|
||||
)
|
||||
elif self._model_type == "xlm":
|
||||
from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
|
||||
from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
|
||||
convert_xlm_checkpoint_to_pytorch,
|
||||
)
|
||||
|
||||
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
|
||||
else:
|
||||
|
||||
@@ -8,13 +8,16 @@ def download_command_factory(args):
|
||||
|
||||
|
||||
class DownloadCommand(BaseTransformersCLICommand):
|
||||
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
download_parser = parser.add_parser('download')
|
||||
download_parser.add_argument('--cache-dir', type=str, default=None, help='Path to location to store the models')
|
||||
download_parser.add_argument('--force', action='store_true', help='Force the model to be download even if already in cache-dir')
|
||||
download_parser.add_argument('model', type=str, help='Name of the model to download')
|
||||
download_parser = parser.add_parser("download")
|
||||
download_parser.add_argument(
|
||||
"--cache-dir", type=str, default=None, help="Path to location to store the models"
|
||||
)
|
||||
download_parser.add_argument(
|
||||
"--force", action="store_true", help="Force the model to be download even if already in cache-dir"
|
||||
)
|
||||
download_parser.add_argument("model", type=str, help="Name of the model to download")
|
||||
download_parser.set_defaults(func=download_command_factory)
|
||||
|
||||
def __init__(self, model: str, cache: str, force: bool):
|
||||
|
||||
@@ -10,52 +10,72 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
def try_infer_format_from_ext(path: str):
|
||||
if not path:
|
||||
return 'pipe'
|
||||
return "pipe"
|
||||
|
||||
for ext in PipelineDataFormat.SUPPORTED_FORMATS:
|
||||
if path.endswith(ext):
|
||||
return ext
|
||||
|
||||
raise Exception(
|
||||
'Unable to determine file format from file extension {}. '
|
||||
'Please provide the format through --format {}'.format(path, PipelineDataFormat.SUPPORTED_FORMATS)
|
||||
"Unable to determine file format from file extension {}. "
|
||||
"Please provide the format through --format {}".format(path, PipelineDataFormat.SUPPORTED_FORMATS)
|
||||
)
|
||||
|
||||
|
||||
def run_command_factory(args):
|
||||
nlp = pipeline(task=args.task,
|
||||
model=args.model if args.model else None,
|
||||
config=args.config,
|
||||
tokenizer=args.tokenizer,
|
||||
device=args.device)
|
||||
format = try_infer_format_from_ext(args.input) if args.format == 'infer' else args.format
|
||||
reader = PipelineDataFormat.from_str(format=format,
|
||||
output_path=args.output,
|
||||
input_path=args.input,
|
||||
column=args.column if args.column else nlp.default_input_names,
|
||||
overwrite=args.overwrite)
|
||||
nlp = pipeline(
|
||||
task=args.task,
|
||||
model=args.model if args.model else None,
|
||||
config=args.config,
|
||||
tokenizer=args.tokenizer,
|
||||
device=args.device,
|
||||
)
|
||||
format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format
|
||||
reader = PipelineDataFormat.from_str(
|
||||
format=format,
|
||||
output_path=args.output,
|
||||
input_path=args.input,
|
||||
column=args.column if args.column else nlp.default_input_names,
|
||||
overwrite=args.overwrite,
|
||||
)
|
||||
return RunCommand(nlp, reader)
|
||||
|
||||
|
||||
class RunCommand(BaseTransformersCLICommand):
|
||||
|
||||
def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
|
||||
self._nlp = nlp
|
||||
self._reader = reader
|
||||
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
run_parser = parser.add_parser('run', help="Run a pipeline through the CLI")
|
||||
run_parser.add_argument('--task', choices=SUPPORTED_TASKS.keys(), help='Task to run')
|
||||
run_parser.add_argument('--input', type=str, help='Path to the file to use for inference')
|
||||
run_parser.add_argument('--output', type=str, help='Path to the file that will be used post to write results.')
|
||||
run_parser.add_argument('--model', type=str, help='Name or path to the model to instantiate.')
|
||||
run_parser.add_argument('--config', type=str, help='Name or path to the model\'s config to instantiate.')
|
||||
run_parser.add_argument('--tokenizer', type=str, help='Name of the tokenizer to use. (default: same as the model name)')
|
||||
run_parser.add_argument('--column', type=str, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)')
|
||||
run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
|
||||
run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
|
||||
run_parser.add_argument('--overwrite', action='store_true', help='Allow overwriting the output file.')
|
||||
run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
|
||||
run_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run")
|
||||
run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
|
||||
run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
|
||||
run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
|
||||
run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
|
||||
run_parser.add_argument(
|
||||
"--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--column",
|
||||
type=str,
|
||||
help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--format",
|
||||
type=str,
|
||||
default="infer",
|
||||
choices=PipelineDataFormat.SUPPORTED_FORMATS,
|
||||
help="Input format to read from",
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--device",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
|
||||
)
|
||||
run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
|
||||
run_parser.set_defaults(func=run_command_factory)
|
||||
|
||||
def run(self):
|
||||
@@ -71,9 +91,6 @@ class RunCommand(BaseTransformersCLICommand):
|
||||
# Saving data
|
||||
if self._nlp.binary_output:
|
||||
binary_path = self._reader.save_binary(outputs)
|
||||
logger.warning('Current pipeline requires output to be in binary format, saving at {}'.format(binary_path))
|
||||
logger.warning("Current pipeline requires output to be in binary format, saving at {}".format(binary_path))
|
||||
else:
|
||||
self._reader.save(outputs)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ try:
|
||||
from uvicorn import run
|
||||
from fastapi import FastAPI, HTTPException, Body
|
||||
from pydantic import BaseModel
|
||||
|
||||
_serve_dependancies_installed = True
|
||||
except (ImportError, AttributeError):
|
||||
BaseModel = object
|
||||
@@ -17,18 +18,21 @@ from transformers import Pipeline
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||
|
||||
logger = logging.getLogger('transformers-cli/serving')
|
||||
logger = logging.getLogger("transformers-cli/serving")
|
||||
|
||||
|
||||
def serve_command_factory(args: Namespace):
|
||||
"""
|
||||
Factory function used to instantiate serving server from provided command line arguments.
|
||||
:return: ServeCommand
|
||||
"""
|
||||
nlp = pipeline(task=args.task,
|
||||
model=args.model if args.model else None,
|
||||
config=args.config,
|
||||
tokenizer=args.tokenizer,
|
||||
device=args.device)
|
||||
nlp = pipeline(
|
||||
task=args.task,
|
||||
model=args.model if args.model else None,
|
||||
config=args.config,
|
||||
tokenizer=args.tokenizer,
|
||||
device=args.device,
|
||||
)
|
||||
return ServeCommand(nlp, args.host, args.port)
|
||||
|
||||
|
||||
@@ -36,6 +40,7 @@ class ServeModelInfoResult(BaseModel):
|
||||
"""
|
||||
Expose model information
|
||||
"""
|
||||
|
||||
infos: dict
|
||||
|
||||
|
||||
@@ -43,6 +48,7 @@ class ServeTokenizeResult(BaseModel):
|
||||
"""
|
||||
Tokenize result model
|
||||
"""
|
||||
|
||||
tokens: List[str]
|
||||
tokens_ids: Optional[List[int]]
|
||||
|
||||
@@ -51,6 +57,7 @@ class ServeDeTokenizeResult(BaseModel):
|
||||
"""
|
||||
DeTokenize result model
|
||||
"""
|
||||
|
||||
text: str
|
||||
|
||||
|
||||
@@ -58,11 +65,11 @@ class ServeForwardResult(BaseModel):
|
||||
"""
|
||||
Forward result model
|
||||
"""
|
||||
|
||||
output: Any
|
||||
|
||||
|
||||
class ServeCommand(BaseTransformersCLICommand):
|
||||
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
"""
|
||||
@@ -70,14 +77,23 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
:param parser: Root parser to register command-specific arguments
|
||||
:return:
|
||||
"""
|
||||
serve_parser = parser.add_parser('serve', help='CLI tool to run inference requests through REST and GraphQL endpoints.')
|
||||
serve_parser.add_argument('--task', type=str, choices=SUPPORTED_TASKS.keys(), help='The task to run the pipeline on')
|
||||
serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
|
||||
serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.')
|
||||
serve_parser.add_argument('--model', type=str, help='Model\'s name or path to stored model.')
|
||||
serve_parser.add_argument('--config', type=str, help='Model\'s config name or path to stored model.')
|
||||
serve_parser.add_argument('--tokenizer', type=str, help='Tokenizer name to use.')
|
||||
serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
|
||||
serve_parser = parser.add_parser(
|
||||
"serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
|
||||
)
|
||||
serve_parser.add_argument(
|
||||
"--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on"
|
||||
)
|
||||
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
|
||||
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
|
||||
serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
|
||||
serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
|
||||
serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
|
||||
serve_parser.add_argument(
|
||||
"--device",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
|
||||
)
|
||||
serve_parser.set_defaults(func=serve_command_factory)
|
||||
|
||||
def __init__(self, pipeline: Pipeline, host: str, port: int):
|
||||
@@ -87,18 +103,22 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
self._host = host
|
||||
self._port = port
|
||||
if not _serve_dependancies_installed:
|
||||
raise ImportError("Using serve command requires FastAPI and unicorn. "
|
||||
"Please install transformers with [serving]: pip install transformers[serving]."
|
||||
"Or install FastAPI and unicorn separatly.")
|
||||
raise ImportError(
|
||||
"Using serve command requires FastAPI and unicorn. "
|
||||
"Please install transformers with [serving]: pip install transformers[serving]."
|
||||
"Or install FastAPI and unicorn separatly."
|
||||
)
|
||||
else:
|
||||
logger.info('Serving model over {}:{}'.format(host, port))
|
||||
logger.info("Serving model over {}:{}".format(host, port))
|
||||
self._app = FastAPI()
|
||||
|
||||
# Register routes
|
||||
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
|
||||
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
|
||||
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
|
||||
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
|
||||
self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
|
||||
self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
|
||||
self._app.add_api_route(
|
||||
"/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
|
||||
)
|
||||
self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])
|
||||
|
||||
def run(self):
|
||||
run(self._app, host=self._host, port=self._port)
|
||||
@@ -122,11 +142,14 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
return ServeTokenizeResult(tokens=tokens_txt)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
|
||||
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
||||
|
||||
def detokenize(self, tokens_ids: List[int] = Body(None, embed=True),
|
||||
skip_special_tokens: bool = Body(False, embed=True),
|
||||
cleanup_tokenization_spaces: bool = Body(True, embed=True)):
|
||||
def detokenize(
|
||||
self,
|
||||
tokens_ids: List[int] = Body(None, embed=True),
|
||||
skip_special_tokens: bool = Body(False, embed=True),
|
||||
cleanup_tokenization_spaces: bool = Body(True, embed=True),
|
||||
):
|
||||
"""
|
||||
Detokenize the provided tokens ids to readable text:
|
||||
- **tokens_ids**: List of tokens ids
|
||||
@@ -135,9 +158,9 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
"""
|
||||
try:
|
||||
decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
|
||||
return ServeDeTokenizeResult(model='', text=decoded_str)
|
||||
return ServeDeTokenizeResult(model="", text=decoded_str)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
|
||||
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
||||
|
||||
def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
|
||||
"""
|
||||
|
||||
@@ -3,9 +3,12 @@ from argparse import ArgumentParser, Namespace
|
||||
from logging import getLogger
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers import (is_tf_available, is_torch_available,
|
||||
TextClassificationPipeline,
|
||||
SingleSentenceClassificationProcessor as Processor)
|
||||
from transformers import (
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
TextClassificationPipeline,
|
||||
SingleSentenceClassificationProcessor as Processor,
|
||||
)
|
||||
|
||||
if not is_tf_available() and not is_torch_available():
|
||||
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
||||
@@ -14,6 +17,7 @@ if not is_tf_available() and not is_torch_available():
|
||||
USE_XLA = False
|
||||
USE_AMP = False
|
||||
|
||||
|
||||
def train_command_factory(args: Namespace):
|
||||
"""
|
||||
Factory function used to instantiate serving server from provided command line arguments.
|
||||
@@ -23,7 +27,6 @@ def train_command_factory(args: Namespace):
|
||||
|
||||
|
||||
class TrainCommand(BaseTransformersCLICommand):
|
||||
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
"""
|
||||
@@ -31,47 +34,54 @@ class TrainCommand(BaseTransformersCLICommand):
|
||||
:param parser: Root parser to register command-specific arguments
|
||||
:return:
|
||||
"""
|
||||
train_parser = parser.add_parser('train', help='CLI tool to train a model on a task.')
|
||||
train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")
|
||||
|
||||
train_parser.add_argument('--train_data', type=str, required=True,
|
||||
help="path to train (and optionally evaluation) dataset as a csv with "
|
||||
"tab separated labels and sentences.")
|
||||
train_parser.add_argument('--column_label', type=int, default=0,
|
||||
help='Column of the dataset csv file with example labels.')
|
||||
train_parser.add_argument('--column_text', type=int, default=1,
|
||||
help='Column of the dataset csv file with example texts.')
|
||||
train_parser.add_argument('--column_id', type=int, default=2,
|
||||
help='Column of the dataset csv file with example ids.')
|
||||
train_parser.add_argument('--skip_first_row', action='store_true',
|
||||
help='Skip the first row of the csv file (headers).')
|
||||
train_parser.add_argument(
|
||||
"--train_data",
|
||||
type=str,
|
||||
required=True,
|
||||
help="path to train (and optionally evaluation) dataset as a csv with "
|
||||
"tab separated labels and sentences.",
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
|
||||
)
|
||||
|
||||
train_parser.add_argument('--validation_data', type=str, default='',
|
||||
help='path to validation dataset.')
|
||||
train_parser.add_argument('--validation_split', type=float, default=0.1,
|
||||
help="if validation dataset is not provided, fraction of train dataset "
|
||||
"to use as validation dataset.")
|
||||
train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
|
||||
train_parser.add_argument(
|
||||
"--validation_split",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="if validation dataset is not provided, fraction of train dataset " "to use as validation dataset.",
|
||||
)
|
||||
|
||||
train_parser.add_argument('--output', type=str, default='./',
|
||||
help='path to saved the trained model.')
|
||||
train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")
|
||||
|
||||
train_parser.add_argument('--task', type=str, default='text_classification',
|
||||
help='Task to train the model on.')
|
||||
train_parser.add_argument('--model', type=str, default='bert-base-uncased',
|
||||
help='Model\'s name or path to stored model.')
|
||||
train_parser.add_argument('--train_batch_size', type=int, default=32,
|
||||
help='Batch size for training.')
|
||||
train_parser.add_argument('--valid_batch_size', type=int, default=64,
|
||||
help='Batch size for validation.')
|
||||
train_parser.add_argument('--learning_rate', type=float, default=3e-5,
|
||||
help="Learning rate.")
|
||||
train_parser.add_argument('--adam_epsilon', type=float, default=1e-08,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
train_parser.add_argument(
|
||||
"--task", type=str, default="text_classification", help="Task to train the model on."
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model."
|
||||
)
|
||||
train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
|
||||
train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
|
||||
train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
|
||||
train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
|
||||
train_parser.set_defaults(func=train_command_factory)
|
||||
|
||||
def __init__(self, args: Namespace):
|
||||
self.logger = getLogger('transformers-cli/training')
|
||||
self.logger = getLogger("transformers-cli/training")
|
||||
|
||||
self.framework = 'tf' if is_tf_available() else 'torch'
|
||||
self.framework = "tf" if is_tf_available() else "torch"
|
||||
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
assert os.path.isdir(args.output)
|
||||
@@ -81,28 +91,32 @@ class TrainCommand(BaseTransformersCLICommand):
|
||||
self.column_text = args.column_text
|
||||
self.column_id = args.column_id
|
||||
|
||||
self.logger.info('Loading {} pipeline for {}'.format(args.task, args.model))
|
||||
if args.task == 'text_classification':
|
||||
self.logger.info("Loading {} pipeline for {}".format(args.task, args.model))
|
||||
if args.task == "text_classification":
|
||||
self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
|
||||
elif args.task == 'token_classification':
|
||||
elif args.task == "token_classification":
|
||||
raise NotImplementedError
|
||||
elif args.task == 'question_answering':
|
||||
elif args.task == "question_answering":
|
||||
raise NotImplementedError
|
||||
|
||||
self.logger.info('Loading dataset from {}'.format(args.train_data))
|
||||
self.train_dataset = Processor.create_from_csv(args.train_data,
|
||||
column_label=args.column_label,
|
||||
column_text=args.column_text,
|
||||
column_id=args.column_id,
|
||||
skip_first_row=args.skip_first_row)
|
||||
self.logger.info("Loading dataset from {}".format(args.train_data))
|
||||
self.train_dataset = Processor.create_from_csv(
|
||||
args.train_data,
|
||||
column_label=args.column_label,
|
||||
column_text=args.column_text,
|
||||
column_id=args.column_id,
|
||||
skip_first_row=args.skip_first_row,
|
||||
)
|
||||
self.valid_dataset = None
|
||||
if args.validation_data:
|
||||
self.logger.info('Loading validation dataset from {}'.format(args.validation_data))
|
||||
self.valid_dataset = Processor.create_from_csv(args.validation_data,
|
||||
column_label=args.column_label,
|
||||
column_text=args.column_text,
|
||||
column_id=args.column_id,
|
||||
skip_first_row=args.skip_first_row)
|
||||
self.logger.info("Loading validation dataset from {}".format(args.validation_data))
|
||||
self.valid_dataset = Processor.create_from_csv(
|
||||
args.validation_data,
|
||||
column_label=args.column_label,
|
||||
column_text=args.column_text,
|
||||
column_id=args.column_id,
|
||||
skip_first_row=args.skip_first_row,
|
||||
)
|
||||
|
||||
self.validation_split = args.validation_split
|
||||
self.train_batch_size = args.train_batch_size
|
||||
@@ -111,7 +125,7 @@ class TrainCommand(BaseTransformersCLICommand):
|
||||
self.adam_epsilon = args.adam_epsilon
|
||||
|
||||
def run(self):
|
||||
if self.framework == 'tf':
|
||||
if self.framework == "tf":
|
||||
return self.run_tf()
|
||||
return self.run_torch()
|
||||
|
||||
@@ -119,13 +133,15 @@ class TrainCommand(BaseTransformersCLICommand):
|
||||
raise NotImplementedError
|
||||
|
||||
def run_tf(self):
|
||||
self.pipeline.fit(self.train_dataset,
|
||||
validation_data=self.valid_dataset,
|
||||
validation_split=self.validation_split,
|
||||
learning_rate=self.learning_rate,
|
||||
adam_epsilon=self.adam_epsilon,
|
||||
train_batch_size=self.train_batch_size,
|
||||
valid_batch_size=self.valid_batch_size)
|
||||
self.pipeline.fit(
|
||||
self.train_dataset,
|
||||
validation_data=self.valid_dataset,
|
||||
validation_split=self.validation_split,
|
||||
learning_rate=self.learning_rate,
|
||||
adam_epsilon=self.adam_epsilon,
|
||||
train_batch_size=self.train_batch_size,
|
||||
valid_batch_size=self.valid_batch_size,
|
||||
)
|
||||
|
||||
# Save trained pipeline
|
||||
self.pipeline.save_pretrained(self.output)
|
||||
|
||||
@@ -9,28 +9,31 @@ from transformers.hf_api import HfApi, HfFolder, HTTPError
|
||||
class UserCommands(BaseTransformersCLICommand):
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
login_parser = parser.add_parser('login')
|
||||
login_parser = parser.add_parser("login")
|
||||
login_parser.set_defaults(func=lambda args: LoginCommand(args))
|
||||
whoami_parser = parser.add_parser('whoami')
|
||||
whoami_parser = parser.add_parser("whoami")
|
||||
whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
|
||||
logout_parser = parser.add_parser('logout')
|
||||
logout_parser = parser.add_parser("logout")
|
||||
logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
|
||||
list_parser = parser.add_parser('ls')
|
||||
list_parser = parser.add_parser("ls")
|
||||
list_parser.set_defaults(func=lambda args: ListObjsCommand(args))
|
||||
# upload
|
||||
upload_parser = parser.add_parser('upload')
|
||||
upload_parser.add_argument('path', type=str, help='Local path of the folder or individual file to upload.')
|
||||
upload_parser.add_argument('--filename', type=str, default=None, help='Optional: override individual object filename on S3.')
|
||||
upload_parser = parser.add_parser("upload")
|
||||
upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
|
||||
upload_parser.add_argument(
|
||||
"--filename", type=str, default=None, help="Optional: override individual object filename on S3."
|
||||
)
|
||||
upload_parser.set_defaults(func=lambda args: UploadCommand(args))
|
||||
|
||||
|
||||
|
||||
class ANSI:
|
||||
"""
|
||||
Helper for en.wikipedia.org/wiki/ANSI_escape_code
|
||||
"""
|
||||
|
||||
_bold = u"\u001b[1m"
|
||||
_reset = u"\u001b[0m"
|
||||
|
||||
@classmethod
|
||||
def bold(cls, s):
|
||||
return "{}{}{}".format(cls._bold, s, cls._reset)
|
||||
@@ -44,14 +47,16 @@ class BaseUserCommand:
|
||||
|
||||
class LoginCommand(BaseUserCommand):
|
||||
def run(self):
|
||||
print("""
|
||||
print(
|
||||
"""
|
||||
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
|
||||
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
|
||||
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
|
||||
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
|
||||
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
|
||||
|
||||
""")
|
||||
"""
|
||||
)
|
||||
username = input("Username: ")
|
||||
password = getpass()
|
||||
try:
|
||||
@@ -101,16 +106,10 @@ class ListObjsCommand(BaseUserCommand):
|
||||
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
|
||||
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
|
||||
lines = []
|
||||
lines.append(
|
||||
row_format.format(*headers)
|
||||
)
|
||||
lines.append(
|
||||
row_format.format(*["-" * w for w in col_widths])
|
||||
)
|
||||
lines.append(row_format.format(*headers))
|
||||
lines.append(row_format.format(*["-" * w for w in col_widths]))
|
||||
for row in rows:
|
||||
lines.append(
|
||||
row_format.format(*row)
|
||||
)
|
||||
lines.append(row_format.format(*row))
|
||||
return "\n".join(lines)
|
||||
|
||||
def run(self):
|
||||
@@ -126,15 +125,8 @@ class ListObjsCommand(BaseUserCommand):
|
||||
if len(objs) == 0:
|
||||
print("No shared file yet")
|
||||
exit()
|
||||
rows = [ [
|
||||
obj.filename,
|
||||
obj.LastModified,
|
||||
obj.ETag,
|
||||
obj.Size
|
||||
] for obj in objs ]
|
||||
print(
|
||||
self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"])
|
||||
)
|
||||
rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
|
||||
print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
|
||||
|
||||
|
||||
class UploadCommand(BaseUserCommand):
|
||||
@@ -143,13 +135,7 @@ class UploadCommand(BaseUserCommand):
|
||||
Recursively list all files in a folder.
|
||||
"""
|
||||
entries: List[os.DirEntry] = list(os.scandir(rel_path))
|
||||
files = [
|
||||
(
|
||||
os.path.join(os.getcwd(), f.path), # filepath
|
||||
f.path # filename
|
||||
)
|
||||
for f in entries if f.is_file()
|
||||
]
|
||||
files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # filepath # filename
|
||||
for f in entries:
|
||||
if f.is_dir():
|
||||
files += self.walk_dir(f.path)
|
||||
@@ -173,22 +159,14 @@ class UploadCommand(BaseUserCommand):
|
||||
raise ValueError("Not a valid file or directory: {}".format(local_path))
|
||||
|
||||
for filepath, filename in files:
|
||||
print(
|
||||
"About to upload file {} to S3 under filename {}".format(
|
||||
ANSI.bold(filepath), ANSI.bold(filename)
|
||||
)
|
||||
)
|
||||
print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))
|
||||
|
||||
choice = input("Proceed? [Y/n] ").lower()
|
||||
if not(choice == "" or choice == "y" or choice == "yes"):
|
||||
if not (choice == "" or choice == "y" or choice == "yes"):
|
||||
print("Abort")
|
||||
exit()
|
||||
print(
|
||||
ANSI.bold("Uploading... This might take a while if files are large")
|
||||
)
|
||||
print(ANSI.bold("Uploading... This might take a while if files are large"))
|
||||
for filepath, filename in files:
|
||||
access_url = self._api.presign_and_upload(
|
||||
token=token, filename=filename, filepath=filepath
|
||||
)
|
||||
access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath)
|
||||
print("Your file now lives at:")
|
||||
print(access_url)
|
||||
|
||||
@@ -18,16 +18,17 @@
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
|
||||
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
|
||||
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
|
||||
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
|
||||
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
|
||||
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
|
||||
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
|
||||
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
|
||||
"albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
|
||||
"albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
|
||||
"albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
|
||||
"albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
|
||||
"albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
|
||||
"albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
|
||||
"albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
|
||||
"albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
|
||||
}
|
||||
|
||||
|
||||
class AlbertConfig(PretrainedConfig):
|
||||
"""Configuration for `AlbertModel`.
|
||||
|
||||
@@ -36,22 +37,25 @@ class AlbertConfig(PretrainedConfig):
|
||||
|
||||
pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size=30000,
|
||||
embedding_size=128,
|
||||
hidden_size=4096,
|
||||
num_hidden_layers=12,
|
||||
num_hidden_groups=1,
|
||||
num_attention_heads=64,
|
||||
intermediate_size=16384,
|
||||
inner_group_num=1,
|
||||
hidden_act="gelu_new",
|
||||
hidden_dropout_prob=0,
|
||||
attention_probs_dropout_prob=0,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30000,
|
||||
embedding_size=128,
|
||||
hidden_size=4096,
|
||||
num_hidden_layers=12,
|
||||
num_hidden_groups=1,
|
||||
num_attention_heads=64,
|
||||
intermediate_size=16384,
|
||||
inner_group_num=1,
|
||||
hidden_act="gelu_new",
|
||||
hidden_dropout_prob=0,
|
||||
attention_probs_dropout_prob=0,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
**kwargs
|
||||
):
|
||||
"""Constructs AlbertConfig.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -35,7 +35,8 @@ from .configuration_xlm_roberta import XLMRobertaConfig, XLM_ROBERTA_PRETRAINED_
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value)
|
||||
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
(key, value)
|
||||
for pretrained_map in [
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@@ -50,8 +51,9 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value)
|
||||
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
]
|
||||
for key, value, in pretrained_map.items())
|
||||
]
|
||||
for key, value, in pretrained_map.items()
|
||||
)
|
||||
|
||||
|
||||
class AutoConfig(object):
|
||||
@@ -79,37 +81,42 @@ class AutoConfig(object):
|
||||
- contains `ctrl` : CTRLConfig (CTRL model)
|
||||
This class cannot be instantiated using `__init__()` (throw an error).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError("AutoConfig is designed to be instantiated "
|
||||
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method.")
|
||||
raise EnvironmentError(
|
||||
"AutoConfig is designed to be instantiated "
|
||||
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, model_type, *args, **kwargs):
|
||||
if 'distilbert' in model_type:
|
||||
if "distilbert" in model_type:
|
||||
return DistilBertConfig(*args, **kwargs)
|
||||
elif 'roberta' in model_type:
|
||||
elif "roberta" in model_type:
|
||||
return RobertaConfig(*args, **kwargs)
|
||||
elif 'bert' in model_type:
|
||||
elif "bert" in model_type:
|
||||
return BertConfig(*args, **kwargs)
|
||||
elif 'openai-gpt' in model_type:
|
||||
elif "openai-gpt" in model_type:
|
||||
return OpenAIGPTConfig(*args, **kwargs)
|
||||
elif 'gpt2' in model_type:
|
||||
elif "gpt2" in model_type:
|
||||
return GPT2Config(*args, **kwargs)
|
||||
elif 'transfo-xl' in model_type:
|
||||
elif "transfo-xl" in model_type:
|
||||
return TransfoXLConfig(*args, **kwargs)
|
||||
elif 'xlnet' in model_type:
|
||||
elif "xlnet" in model_type:
|
||||
return XLNetConfig(*args, **kwargs)
|
||||
elif 'xlm' in model_type:
|
||||
elif "xlm" in model_type:
|
||||
return XLMConfig(*args, **kwargs)
|
||||
elif 'ctrl' in model_type:
|
||||
elif "ctrl" in model_type:
|
||||
return CTRLConfig(*args, **kwargs)
|
||||
elif 'albert' in model_type:
|
||||
elif "albert" in model_type:
|
||||
return AlbertConfig(*args, **kwargs)
|
||||
elif 'camembert' in model_type:
|
||||
elif "camembert" in model_type:
|
||||
return CamembertConfig(*args, **kwargs)
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type))
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
@@ -176,32 +183,36 @@ class AutoConfig(object):
|
||||
assert unused_kwargs == {'foo': False}
|
||||
|
||||
"""
|
||||
if 't5' in pretrained_model_name_or_path:
|
||||
if "t5" in pretrained_model_name_or_path:
|
||||
return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif 'distilbert' in pretrained_model_name_or_path:
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif 'albert' in pretrained_model_name_or_path:
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif 'camembert' in pretrained_model_name_or_path:
|
||||
elif "camembert" in pretrained_model_name_or_path:
|
||||
return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif 'xlm-roberta' in pretrained_model_name_or_path:
|
||||
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||
return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif 'bert' in pretrained_model_name_or_path:
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif 'openai-gpt' in pretrained_model_name_or_path:
|
||||
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||
return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif 'gpt2' in pretrained_model_name_or_path:
|
||||
elif "gpt2" in pretrained_model_name_or_path:
|
||||
return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif 'transfo-xl' in pretrained_model_name_or_path:
|
||||
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||
return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif 'xlnet' in pretrained_model_name_or_path:
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif 'xlm' in pretrained_model_name_or_path:
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif 'ctrl' in pretrained_model_name_or_path:
|
||||
elif "ctrl" in pretrained_model_name_or_path:
|
||||
return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(pretrained_model_name_or_path))
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(
|
||||
pretrained_model_name_or_path
|
||||
)
|
||||
)
|
||||
|
||||
@@ -27,27 +27,27 @@ from .configuration_utils import PretrainedConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
|
||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
|
||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
|
||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
|
||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
|
||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
|
||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
|
||||
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
|
||||
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
|
||||
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
|
||||
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
|
||||
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
|
||||
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
|
||||
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
|
||||
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
|
||||
'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
|
||||
'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
|
||||
'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
|
||||
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
|
||||
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
|
||||
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
|
||||
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
|
||||
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
|
||||
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
|
||||
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
|
||||
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
|
||||
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
|
||||
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
|
||||
"bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
|
||||
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
|
||||
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
|
||||
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
|
||||
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
|
||||
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
|
||||
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
|
||||
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
|
||||
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
|
||||
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
|
||||
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
|
||||
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
|
||||
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
|
||||
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
|
||||
}
|
||||
|
||||
|
||||
@@ -82,20 +82,22 @@ class BertConfig(PretrainedConfig):
|
||||
"""
|
||||
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
**kwargs
|
||||
):
|
||||
super(BertConfig, self).__init__(**kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
# limitations under the License.
|
||||
""" CamemBERT configuration """
|
||||
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import logging
|
||||
|
||||
@@ -25,7 +24,7 @@ from .configuration_roberta import RobertaConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
|
||||
"camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}
|
||||
|
||||
|
||||
class CTRLConfig(PretrainedConfig):
|
||||
"""Configuration class to store the configuration of a `CTRLModel`.
|
||||
|
||||
@@ -48,6 +49,7 @@ class CTRLConfig(PretrainedConfig):
|
||||
initializer_range: The sttdev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
"""
|
||||
|
||||
pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(
|
||||
@@ -64,7 +66,7 @@ class CTRLConfig(PretrainedConfig):
|
||||
attn_pdrop=0.1,
|
||||
layer_norm_epsilon=1e-6,
|
||||
initializer_range=0.02,
|
||||
summary_type='cls_index',
|
||||
summary_type="cls_index",
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
|
||||
@@ -13,8 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" DistilBERT model configuration """
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import sys
|
||||
import json
|
||||
@@ -26,32 +25,34 @@ from .configuration_utils import PretrainedConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
|
||||
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json",
|
||||
'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json",
|
||||
'distilbert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json",
|
||||
"distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
|
||||
"distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json",
|
||||
"distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json",
|
||||
"distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json",
|
||||
}
|
||||
|
||||
|
||||
class DistilBertConfig(PretrainedConfig):
|
||||
pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size=30522,
|
||||
max_position_embeddings=512,
|
||||
sinusoidal_pos_embds=False,
|
||||
n_layers=6,
|
||||
n_heads=12,
|
||||
dim=768,
|
||||
hidden_dim=4*768,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
activation='gelu',
|
||||
initializer_range=0.02,
|
||||
tie_weights_=True,
|
||||
qa_dropout=0.1,
|
||||
seq_classif_dropout=0.2,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
max_position_embeddings=512,
|
||||
sinusoidal_pos_embds=False,
|
||||
n_layers=6,
|
||||
n_heads=12,
|
||||
dim=768,
|
||||
hidden_dim=4 * 768,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
activation="gelu",
|
||||
initializer_range=0.02,
|
||||
tie_weights_=True,
|
||||
qa_dropout=0.1,
|
||||
seq_classif_dropout=0.2,
|
||||
**kwargs
|
||||
):
|
||||
super(DistilBertConfig, self).__init__(**kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
@@ -26,11 +26,14 @@ from .configuration_utils import PretrainedConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
|
||||
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
|
||||
"gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json",
|
||||
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",}
|
||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
|
||||
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
|
||||
"gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json",
|
||||
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",
|
||||
}
|
||||
|
||||
|
||||
class GPT2Config(PretrainedConfig):
|
||||
"""Configuration class to store the configuration of a `GPT2Model`.
|
||||
@@ -52,6 +55,7 @@ class GPT2Config(PretrainedConfig):
|
||||
initializer_range: The sttdev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
"""
|
||||
|
||||
pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(
|
||||
@@ -67,7 +71,7 @@ class GPT2Config(PretrainedConfig):
|
||||
attn_pdrop=0.1,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
summary_type='cls_index',
|
||||
summary_type="cls_index",
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
# limitations under the License.
|
||||
""" MMBT configuration """
|
||||
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import logging
|
||||
|
||||
@@ -31,6 +30,7 @@ class MMBTConfig(object):
|
||||
num_labels: Size of final Linear layer for classification.
|
||||
modal_hidden_size: Embedding dimension of the non-text modality encoder.
|
||||
"""
|
||||
|
||||
def __init__(self, config, num_labels=None, modal_hidden_size=2048):
|
||||
self.__dict__ = config.__dict__
|
||||
self.modal_hidden_size = modal_hidden_size
|
||||
|
||||
@@ -30,6 +30,7 @@ OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"
|
||||
}
|
||||
|
||||
|
||||
class OpenAIGPTConfig(PretrainedConfig):
|
||||
"""
|
||||
Configuration class to store the configuration of a `OpenAIGPTModel`.
|
||||
@@ -54,6 +55,7 @@ class OpenAIGPTConfig(PretrainedConfig):
|
||||
initializing all weight matrices.
|
||||
predict_special_tokens: should we predict special tokens (when the model has a LM head)
|
||||
"""
|
||||
|
||||
pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(
|
||||
@@ -71,7 +73,7 @@ class OpenAIGPTConfig(PretrainedConfig):
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
predict_special_tokens=True,
|
||||
summary_type='cls_index',
|
||||
summary_type="cls_index",
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
# limitations under the License.
|
||||
""" RoBERTa configuration """
|
||||
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import logging
|
||||
|
||||
@@ -25,12 +24,12 @@ from .configuration_bert import BertConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
|
||||
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
|
||||
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
|
||||
'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
|
||||
'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
|
||||
'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
|
||||
"roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
|
||||
"roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
|
||||
"roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
|
||||
"distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
|
||||
"roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
|
||||
"roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -27,11 +27,11 @@ from .configuration_utils import PretrainedConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
|
||||
't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
|
||||
't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
|
||||
't5-3b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json",
|
||||
't5-11b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json",
|
||||
"t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
|
||||
"t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
|
||||
"t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
|
||||
"t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json",
|
||||
"t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json",
|
||||
}
|
||||
|
||||
|
||||
@@ -65,19 +65,21 @@ class T5Config(PretrainedConfig):
|
||||
"""
|
||||
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size=32128,
|
||||
n_positions=512,
|
||||
d_model=512,
|
||||
d_kv=64,
|
||||
d_ff=2048,
|
||||
num_layers=6,
|
||||
num_heads=8,
|
||||
relative_attention_num_buckets=32,
|
||||
dropout_rate=0.1,
|
||||
layer_norm_epsilon=1e-6,
|
||||
initializer_factor=1.0,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32128,
|
||||
n_positions=512,
|
||||
d_model=512,
|
||||
d_kv=64,
|
||||
d_ff=2048,
|
||||
num_layers=6,
|
||||
num_heads=8,
|
||||
relative_attention_num_buckets=32,
|
||||
dropout_rate=0.1,
|
||||
layer_norm_epsilon=1e-6,
|
||||
initializer_factor=1.0,
|
||||
**kwargs
|
||||
):
|
||||
super(T5Config, self).__init__(**kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.n_positions = n_positions
|
||||
|
||||
@@ -27,9 +27,10 @@ from .configuration_utils import PretrainedConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
|
||||
"transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
|
||||
}
|
||||
|
||||
|
||||
class TransfoXLConfig(PretrainedConfig):
|
||||
"""Configuration class to store the configuration of a `TransfoXLModel`.
|
||||
|
||||
@@ -65,38 +66,41 @@ class TransfoXLConfig(PretrainedConfig):
|
||||
proj_init_std: parameters initialized by N(0, init_std)
|
||||
init_std: parameters initialized by N(0, init_std)
|
||||
"""
|
||||
|
||||
pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size=267735,
|
||||
cutoffs=[20000, 40000, 200000],
|
||||
d_model=1024,
|
||||
d_embed=1024,
|
||||
n_head=16,
|
||||
d_head=64,
|
||||
d_inner=4096,
|
||||
div_val=4,
|
||||
pre_lnorm=False,
|
||||
n_layer=18,
|
||||
tgt_len=128,
|
||||
ext_len=0,
|
||||
mem_len=1600,
|
||||
clamp_len=1000,
|
||||
same_length=True,
|
||||
proj_share_all_but_first=True,
|
||||
attn_type=0,
|
||||
sample_softmax=-1,
|
||||
adaptive=True,
|
||||
tie_weight=True,
|
||||
dropout=0.1,
|
||||
dropatt=0.0,
|
||||
untie_r=True,
|
||||
init="normal",
|
||||
init_range=0.01,
|
||||
proj_init_std=0.01,
|
||||
init_std=0.02,
|
||||
layer_norm_epsilon=1e-5,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=267735,
|
||||
cutoffs=[20000, 40000, 200000],
|
||||
d_model=1024,
|
||||
d_embed=1024,
|
||||
n_head=16,
|
||||
d_head=64,
|
||||
d_inner=4096,
|
||||
div_val=4,
|
||||
pre_lnorm=False,
|
||||
n_layer=18,
|
||||
tgt_len=128,
|
||||
ext_len=0,
|
||||
mem_len=1600,
|
||||
clamp_len=1000,
|
||||
same_length=True,
|
||||
proj_share_all_but_first=True,
|
||||
attn_type=0,
|
||||
sample_softmax=-1,
|
||||
adaptive=True,
|
||||
tie_weight=True,
|
||||
dropout=0.1,
|
||||
dropatt=0.0,
|
||||
untie_r=True,
|
||||
init="normal",
|
||||
init_range=0.01,
|
||||
proj_init_std=0.01,
|
||||
init_std=0.02,
|
||||
layer_norm_epsilon=1e-5,
|
||||
**kwargs
|
||||
):
|
||||
"""Constructs TransfoXLConfig.
|
||||
"""
|
||||
super(TransfoXLConfig, self).__init__(**kwargs)
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
# limitations under the License.
|
||||
""" Configuration base class and utilities."""
|
||||
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import copy
|
||||
import json
|
||||
@@ -28,6 +27,7 @@ from .file_utils import CONFIG_NAME, cached_path, is_remote_url, hf_bucket_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PretrainedConfig(object):
|
||||
r""" Base class for all configuration classes.
|
||||
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
|
||||
@@ -50,36 +50,36 @@ class PretrainedConfig(object):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Attributes with defaults
|
||||
self.output_attentions = kwargs.pop('output_attentions', False)
|
||||
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
|
||||
self.output_past = kwargs.pop('output_past', True) # Not used by all models
|
||||
self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models
|
||||
self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
|
||||
self.pruned_heads = kwargs.pop('pruned_heads', {})
|
||||
self.output_attentions = kwargs.pop("output_attentions", False)
|
||||
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
||||
self.output_past = kwargs.pop("output_past", True) # Not used by all models
|
||||
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
||||
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
||||
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
||||
|
||||
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
|
||||
self.is_decoder = kwargs.pop('is_decoder', False)
|
||||
self.is_decoder = kwargs.pop("is_decoder", False)
|
||||
|
||||
# Parameters for sequence generation
|
||||
self.max_length = kwargs.pop('max_length', 20)
|
||||
self.do_sample = kwargs.pop('do_sample', False)
|
||||
self.num_beams = kwargs.pop('num_beams', 1)
|
||||
self.temperature = kwargs.pop('temperature', 1.0)
|
||||
self.top_k = kwargs.pop('top_k', 50)
|
||||
self.top_p = kwargs.pop('top_p', 1.0)
|
||||
self.repetition_penalty = kwargs.pop('repetition_penalty', 1.0)
|
||||
self.bos_token_id = kwargs.pop('bos_token_id', 0)
|
||||
self.pad_token_id = kwargs.pop('pad_token_id', 0)
|
||||
self.eos_token_ids = kwargs.pop('eos_token_ids', 0)
|
||||
self.length_penalty = kwargs.pop('length_penalty', 1.)
|
||||
self.num_return_sequences = kwargs.pop('num_return_sequences', 1)
|
||||
self.max_length = kwargs.pop("max_length", 20)
|
||||
self.do_sample = kwargs.pop("do_sample", False)
|
||||
self.num_beams = kwargs.pop("num_beams", 1)
|
||||
self.temperature = kwargs.pop("temperature", 1.0)
|
||||
self.top_k = kwargs.pop("top_k", 50)
|
||||
self.top_p = kwargs.pop("top_p", 1.0)
|
||||
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
||||
self.bos_token_id = kwargs.pop("bos_token_id", 0)
|
||||
self.pad_token_id = kwargs.pop("pad_token_id", 0)
|
||||
self.eos_token_ids = kwargs.pop("eos_token_ids", 0)
|
||||
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||
|
||||
# Fine-tuning task arguments
|
||||
self.finetuning_task = kwargs.pop('finetuning_task', None)
|
||||
self.num_labels = kwargs.pop('num_labels', 2)
|
||||
self.id2label = kwargs.pop('id2label', {i: 'LABEL_{}'.format(i) for i in range(self.num_labels)})
|
||||
self.finetuning_task = kwargs.pop("finetuning_task", None)
|
||||
self.num_labels = kwargs.pop("num_labels", 2)
|
||||
self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)})
|
||||
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
|
||||
self.label2id = kwargs.pop('label2id', dict(zip(self.id2label.values(), self.id2label.keys())))
|
||||
self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
|
||||
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
|
||||
|
||||
# Additional attributes without default values
|
||||
@@ -94,7 +94,9 @@ class PretrainedConfig(object):
|
||||
""" Save a configuration object to the directory `save_directory`, so that it
|
||||
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
|
||||
"""
|
||||
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
|
||||
assert os.path.isdir(
|
||||
save_directory
|
||||
), "Saving path should be a directory where the model and configuration can be saved"
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
||||
@@ -153,11 +155,11 @@ class PretrainedConfig(object):
|
||||
assert unused_kwargs == {'foo': False}
|
||||
|
||||
"""
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
resume_download = kwargs.pop('resume_download', False)
|
||||
proxies = kwargs.pop('proxies', None)
|
||||
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
|
||||
@@ -170,37 +172,48 @@ class PretrainedConfig(object):
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
|
||||
proxies=proxies, resume_download=resume_download)
|
||||
resolved_config_file = cached_path(
|
||||
config_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
)
|
||||
# Load config
|
||||
config = cls.from_json_file(resolved_config_file)
|
||||
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
|
||||
config_file)
|
||||
config_file
|
||||
)
|
||||
else:
|
||||
msg = "Model name '{}' was not found in model name list ({}). " \
|
||||
"We assumed '{}' was a path or url to a configuration file named {} or " \
|
||||
"a directory containing such a file but couldn't find any such file at this path or url.".format(
|
||||
msg = (
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url to a configuration file named {} or "
|
||||
"a directory containing such a file but couldn't find any such file at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
', '.join(cls.pretrained_config_archive_map.keys()),
|
||||
config_file, CONFIG_NAME)
|
||||
", ".join(cls.pretrained_config_archive_map.keys()),
|
||||
config_file,
|
||||
CONFIG_NAME,
|
||||
)
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
msg = "Couldn't reach server at '{}' to download configuration file or " \
|
||||
"configuration file is not a valid JSON file. " \
|
||||
"Please check network or file content here: {}.".format(config_file, resolved_config_file)
|
||||
msg = (
|
||||
"Couldn't reach server at '{}' to download configuration file or "
|
||||
"configuration file is not a valid JSON file. "
|
||||
"Please check network or file content here: {}.".format(config_file, resolved_config_file)
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
if resolved_config_file == config_file:
|
||||
logger.info("loading configuration file {}".format(config_file))
|
||||
else:
|
||||
logger.info("loading configuration file {} from cache at {}".format(
|
||||
config_file, resolved_config_file))
|
||||
logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
|
||||
|
||||
if hasattr(config, 'pruned_heads'):
|
||||
if hasattr(config, "pruned_heads"):
|
||||
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
|
||||
|
||||
# Update config with kwargs if needed
|
||||
@@ -226,7 +239,7 @@ class PretrainedConfig(object):
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `Config` from a json file of parameters."""
|
||||
with open(json_file, "r", encoding='utf-8') as reader:
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
dict_obj = json.loads(text)
|
||||
return cls(**dict_obj)
|
||||
@@ -248,5 +261,5 @@ class PretrainedConfig(object):
|
||||
|
||||
def to_json_file(self, json_file_path):
|
||||
""" Save this instance to a json file."""
|
||||
with open(json_file_path, "w", encoding='utf-8') as writer:
|
||||
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
@@ -25,16 +25,16 @@ from .configuration_utils import PretrainedConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
|
||||
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
|
||||
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
|
||||
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
|
||||
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
|
||||
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
|
||||
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
|
||||
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
|
||||
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
|
||||
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
|
||||
"xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
|
||||
"xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
|
||||
"xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
|
||||
"xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
|
||||
"xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
|
||||
"xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
|
||||
"xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
|
||||
"xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
|
||||
"xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
|
||||
"xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
|
||||
}
|
||||
|
||||
|
||||
@@ -78,41 +78,44 @@ class XLMConfig(PretrainedConfig):
|
||||
-1 means no clamping.
|
||||
same_length: bool, whether to use the same attention length for each token.
|
||||
"""
|
||||
|
||||
pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size=30145,
|
||||
emb_dim=2048,
|
||||
n_layers=12,
|
||||
n_heads=16,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
gelu_activation=True,
|
||||
sinusoidal_embeddings=False,
|
||||
causal=False,
|
||||
asm=False,
|
||||
n_langs=1,
|
||||
use_lang_emb=True,
|
||||
max_position_embeddings=512,
|
||||
embed_init_std=2048 ** -0.5,
|
||||
layer_norm_eps=1e-12,
|
||||
init_std=0.02,
|
||||
bos_index=0,
|
||||
eos_index=1,
|
||||
pad_index=2,
|
||||
unk_index=3,
|
||||
mask_index=5,
|
||||
is_encoder=True,
|
||||
summary_type='first',
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
summary_first_dropout=0.1,
|
||||
start_n_top=5,
|
||||
end_n_top=5,
|
||||
mask_token_id=0,
|
||||
lang_id=0,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30145,
|
||||
emb_dim=2048,
|
||||
n_layers=12,
|
||||
n_heads=16,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
gelu_activation=True,
|
||||
sinusoidal_embeddings=False,
|
||||
causal=False,
|
||||
asm=False,
|
||||
n_langs=1,
|
||||
use_lang_emb=True,
|
||||
max_position_embeddings=512,
|
||||
embed_init_std=2048 ** -0.5,
|
||||
layer_norm_eps=1e-12,
|
||||
init_std=0.02,
|
||||
bos_index=0,
|
||||
eos_index=1,
|
||||
pad_index=2,
|
||||
unk_index=3,
|
||||
mask_index=5,
|
||||
is_encoder=True,
|
||||
summary_type="first",
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
summary_first_dropout=0.1,
|
||||
start_n_top=5,
|
||||
end_n_top=5,
|
||||
mask_token_id=0,
|
||||
lang_id=0,
|
||||
**kwargs
|
||||
):
|
||||
"""Constructs XLMConfig.
|
||||
"""
|
||||
super(XLMConfig, self).__init__(**kwargs)
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
# limitations under the License.
|
||||
""" XLM-RoBERTa configuration """
|
||||
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import logging
|
||||
|
||||
@@ -25,12 +24,12 @@ from .configuration_roberta import RobertaConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'xlm-roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
|
||||
'xlm-roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json",
|
||||
'xlm-roberta-large-finetuned-conll02-dutch': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
|
||||
'xlm-roberta-large-finetuned-conll02-spanish': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
|
||||
'xlm-roberta-large-finetuned-conll03-english': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
|
||||
'xlm-roberta-large-finetuned-conll03-german': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
|
||||
"xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
|
||||
"xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json",
|
||||
"xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
|
||||
"xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
|
||||
"xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
|
||||
"xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -26,8 +26,8 @@ from .configuration_utils import PretrainedConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
|
||||
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
|
||||
"xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
|
||||
"xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
|
||||
}
|
||||
|
||||
|
||||
@@ -69,32 +69,35 @@ class XLNetConfig(PretrainedConfig):
|
||||
same_length: bool, whether to use the same attention length for each token.
|
||||
finetuning_task: name of the glue task on which the model was fine-tuned if any
|
||||
"""
|
||||
|
||||
pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size=32000,
|
||||
d_model=1024,
|
||||
n_layer=24,
|
||||
n_head=16,
|
||||
d_inner=4096,
|
||||
ff_activation="gelu",
|
||||
untie_r=True,
|
||||
attn_type="bi",
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
dropout=0.1,
|
||||
mem_len=None,
|
||||
reuse_len=None,
|
||||
bi_data=False,
|
||||
clamp_len=-1,
|
||||
same_length=False,
|
||||
summary_type='last',
|
||||
summary_use_proj=True,
|
||||
summary_activation='tanh',
|
||||
summary_last_dropout=0.1,
|
||||
start_n_top=5,
|
||||
end_n_top=5,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
d_model=1024,
|
||||
n_layer=24,
|
||||
n_head=16,
|
||||
d_inner=4096,
|
||||
ff_activation="gelu",
|
||||
untie_r=True,
|
||||
attn_type="bi",
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
dropout=0.1,
|
||||
mem_len=None,
|
||||
reuse_len=None,
|
||||
bi_data=False,
|
||||
clamp_len=-1,
|
||||
same_length=False,
|
||||
summary_type="last",
|
||||
summary_use_proj=True,
|
||||
summary_activation="tanh",
|
||||
summary_last_dropout=0.1,
|
||||
start_n_top=5,
|
||||
end_n_top=5,
|
||||
**kwargs
|
||||
):
|
||||
"""Constructs XLNetConfig.
|
||||
"""
|
||||
super(XLNetConfig, self).__init__(**kwargs)
|
||||
|
||||
@@ -24,6 +24,7 @@ import torch
|
||||
from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
@@ -44,24 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--albert_config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained ALBERT model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--albert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained ALBERT model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
||||
args.albert_config_file,
|
||||
args.pytorch_dump_path)
|
||||
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
|
||||
|
||||
@@ -24,8 +24,10 @@ import torch
|
||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = BertConfig.from_json_file(bert_config_file)
|
||||
@@ -43,23 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--bert_config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
||||
args.bert_config_file,
|
||||
args.pytorch_dump_path)
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
||||
|
||||
@@ -23,7 +23,7 @@ import tensorflow as tf
|
||||
from transformers import BertModel
|
||||
|
||||
|
||||
def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
|
||||
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
|
||||
|
||||
"""
|
||||
:param model:BertModel Pytorch model instance to be converted
|
||||
@@ -41,22 +41,17 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
|
||||
N BertForQuestionAnswering
|
||||
"""
|
||||
|
||||
tensors_to_transpose = (
|
||||
"dense.weight",
|
||||
"attention.self.query",
|
||||
"attention.self.key",
|
||||
"attention.self.value"
|
||||
)
|
||||
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
|
||||
|
||||
var_map = (
|
||||
('layer.', 'layer_'),
|
||||
('word_embeddings.weight', 'word_embeddings'),
|
||||
('position_embeddings.weight', 'position_embeddings'),
|
||||
('token_type_embeddings.weight', 'token_type_embeddings'),
|
||||
('.', '/'),
|
||||
('LayerNorm/weight', 'LayerNorm/gamma'),
|
||||
('LayerNorm/bias', 'LayerNorm/beta'),
|
||||
('weight', 'kernel')
|
||||
("layer.", "layer_"),
|
||||
("word_embeddings.weight", "word_embeddings"),
|
||||
("position_embeddings.weight", "position_embeddings"),
|
||||
("token_type_embeddings.weight", "token_type_embeddings"),
|
||||
(".", "/"),
|
||||
("LayerNorm/weight", "LayerNorm/gamma"),
|
||||
("LayerNorm/bias", "LayerNorm/beta"),
|
||||
("weight", "kernel"),
|
||||
)
|
||||
|
||||
if not os.path.isdir(ckpt_dir):
|
||||
@@ -64,12 +59,12 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
|
||||
|
||||
state_dict = model.state_dict()
|
||||
|
||||
def to_tf_var_name(name:str):
|
||||
def to_tf_var_name(name: str):
|
||||
for patt, repl in iter(var_map):
|
||||
name = name.replace(patt, repl)
|
||||
return 'bert/{}'.format(name)
|
||||
return "bert/{}".format(name)
|
||||
|
||||
def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session):
|
||||
def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
|
||||
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
|
||||
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
|
||||
session.run(tf.variables_initializer([tf_var]))
|
||||
@@ -94,36 +89,21 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
|
||||
|
||||
def main(raw_args=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="model name e.g. bert-base-uncased")
|
||||
parser.add_argument("--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Directory containing pytorch model")
|
||||
parser.add_argument("--pytorch_model_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="/path/to/<pytorch-model-name>.bin")
|
||||
parser.add_argument("--tf_cache_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory in which to save tensorflow model")
|
||||
parser.add_argument("--model_name", type=str, required=True, help="model name e.g. bert-base-uncased")
|
||||
parser.add_argument(
|
||||
"--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
|
||||
)
|
||||
parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
|
||||
parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
|
||||
args = parser.parse_args(raw_args)
|
||||
|
||||
model = BertModel.from_pretrained(
|
||||
pretrained_model_name_or_path=args.model_name,
|
||||
state_dict=torch.load(args.pytorch_model_path),
|
||||
cache_dir=args.cache_dir
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
|
||||
convert_pytorch_checkpoint_to_tf(
|
||||
model=model,
|
||||
ckpt_dir=args.tf_cache_dir,
|
||||
model_name=args.model_name
|
||||
)
|
||||
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -21,12 +21,10 @@ from io import open
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import (CONFIG_NAME, WEIGHTS_NAME,
|
||||
GPT2Config,
|
||||
GPT2Model,
|
||||
load_tf_weights_in_gpt2)
|
||||
from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
@@ -42,8 +40,8 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
|
||||
load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||
@@ -54,22 +52,18 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--gpt2_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--pytorch_dump_folder_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
parser.add_argument("--gpt2_config_file",
|
||||
default = "",
|
||||
type = str,
|
||||
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument(
|
||||
"--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpt2_config_file",
|
||||
default="",
|
||||
type=str,
|
||||
help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
|
||||
args.gpt2_config_file,
|
||||
args.pytorch_dump_folder_path)
|
||||
convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path)
|
||||
|
||||
@@ -21,12 +21,10 @@ from io import open
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import (CONFIG_NAME, WEIGHTS_NAME,
|
||||
OpenAIGPTConfig,
|
||||
OpenAIGPTModel,
|
||||
load_tf_weights_in_openai_gpt)
|
||||
from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
@@ -42,8 +40,8 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
||||
load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)
|
||||
|
||||
# Save pytorch-model
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||
@@ -54,22 +52,24 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--openai_checkpoint_folder_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--pytorch_dump_folder_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
parser.add_argument("--openai_config_file",
|
||||
default = "",
|
||||
type = str,
|
||||
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument(
|
||||
"--openai_checkpoint_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the TensorFlow checkpoint path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--openai_config_file",
|
||||
default="",
|
||||
type=str,
|
||||
help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
|
||||
args.openai_config_file,
|
||||
args.pytorch_dump_folder_path)
|
||||
convert_openai_checkpoint_to_pytorch(
|
||||
args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path
|
||||
)
|
||||
|
||||
@@ -24,82 +24,270 @@ import tensorflow as tf
|
||||
|
||||
from transformers import is_torch_available, cached_path
|
||||
|
||||
from transformers import (load_pytorch_checkpoint_in_tf2_model,
|
||||
BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
GPT2Config, TFGPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLNetConfig, TFXLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLMConfig, TFXLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, TFDistilBertForSequenceClassification, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
AlbertConfig, TFAlbertForMaskedLM, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
T5Config, TFT5WithLMHeadModel, T5_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||
from transformers import (
|
||||
load_pytorch_checkpoint_in_tf2_model,
|
||||
BertConfig,
|
||||
TFBertForPreTraining,
|
||||
TFBertForQuestionAnswering,
|
||||
TFBertForSequenceClassification,
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
GPT2Config,
|
||||
TFGPT2LMHeadModel,
|
||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLNetConfig,
|
||||
TFXLNetLMHeadModel,
|
||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLMConfig,
|
||||
TFXLMWithLMHeadModel,
|
||||
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
TransfoXLConfig,
|
||||
TFTransfoXLLMHeadModel,
|
||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
OpenAIGPTConfig,
|
||||
TFOpenAIGPTLMHeadModel,
|
||||
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
RobertaConfig,
|
||||
TFRobertaForMaskedLM,
|
||||
TFRobertaForSequenceClassification,
|
||||
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DistilBertConfig,
|
||||
TFDistilBertForMaskedLM,
|
||||
TFDistilBertForQuestionAnswering,
|
||||
TFDistilBertForSequenceClassification,
|
||||
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CTRLConfig,
|
||||
TFCTRLLMHeadModel,
|
||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
AlbertConfig,
|
||||
TFAlbertForMaskedLM,
|
||||
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
T5Config,
|
||||
TFT5WithLMHeadModel,
|
||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from transformers import (
|
||||
BertForPreTraining,
|
||||
BertForQuestionAnswering,
|
||||
BertForSequenceClassification,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
GPT2LMHeadModel,
|
||||
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNetLMHeadModel,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLMWithLMHeadModel,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
TransfoXLLMHeadModel,
|
||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
OpenAIGPTLMHeadModel,
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
RobertaForMaskedLM,
|
||||
RobertaForSequenceClassification,
|
||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DistilBertForMaskedLM,
|
||||
DistilBertForQuestionAnswering,
|
||||
DistilBertForSequenceClassification,
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
CTRLLMHeadModel,
|
||||
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
AlbertForMaskedLM,
|
||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5WithLMHeadModel,
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
else:
|
||||
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DistilBertForMaskedLM, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) = (
|
||||
None, None, None, None,
|
||||
None, None,
|
||||
None, None,
|
||||
None, None,
|
||||
None, None,
|
||||
None, None,
|
||||
None, None, None,
|
||||
None, None, None, None,
|
||||
None, None,
|
||||
None, None,
|
||||
None, None)
|
||||
(
|
||||
BertForPreTraining,
|
||||
BertForQuestionAnswering,
|
||||
BertForSequenceClassification,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
GPT2LMHeadModel,
|
||||
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNetLMHeadModel,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLMWithLMHeadModel,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
TransfoXLLMHeadModel,
|
||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
OpenAIGPTLMHeadModel,
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
RobertaForMaskedLM,
|
||||
RobertaForSequenceClassification,
|
||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DistilBertForMaskedLM,
|
||||
DistilBertForSequenceClassification,
|
||||
DistilBertForQuestionAnswering,
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
CTRLLMHeadModel,
|
||||
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
AlbertForMaskedLM,
|
||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5WithLMHeadModel,
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
) = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'gpt2': (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'xlm': (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'roberta': (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'albert': (AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
't5': (T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
"bert": (
|
||||
BertConfig,
|
||||
TFBertForPreTraining,
|
||||
BertForPreTraining,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"bert-large-uncased-whole-word-masking-finetuned-squad": (
|
||||
BertConfig,
|
||||
TFBertForQuestionAnswering,
|
||||
BertForQuestionAnswering,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"bert-large-cased-whole-word-masking-finetuned-squad": (
|
||||
BertConfig,
|
||||
TFBertForQuestionAnswering,
|
||||
BertForQuestionAnswering,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"bert-base-cased-finetuned-mrpc": (
|
||||
BertConfig,
|
||||
TFBertForSequenceClassification,
|
||||
BertForSequenceClassification,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"gpt2": (
|
||||
GPT2Config,
|
||||
TFGPT2LMHeadModel,
|
||||
GPT2LMHeadModel,
|
||||
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"xlnet": (
|
||||
XLNetConfig,
|
||||
TFXLNetLMHeadModel,
|
||||
XLNetLMHeadModel,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"xlm": (
|
||||
XLMConfig,
|
||||
TFXLMWithLMHeadModel,
|
||||
XLMWithLMHeadModel,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"transfo-xl": (
|
||||
TransfoXLConfig,
|
||||
TFTransfoXLLMHeadModel,
|
||||
TransfoXLLMHeadModel,
|
||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"openai-gpt": (
|
||||
OpenAIGPTConfig,
|
||||
TFOpenAIGPTLMHeadModel,
|
||||
OpenAIGPTLMHeadModel,
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"roberta": (
|
||||
RobertaConfig,
|
||||
TFRobertaForMaskedLM,
|
||||
RobertaForMaskedLM,
|
||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"roberta-large-mnli": (
|
||||
RobertaConfig,
|
||||
TFRobertaForSequenceClassification,
|
||||
RobertaForSequenceClassification,
|
||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"distilbert": (
|
||||
DistilBertConfig,
|
||||
TFDistilBertForMaskedLM,
|
||||
DistilBertForMaskedLM,
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"distilbert-base-uncased-distilled-squad": (
|
||||
DistilBertConfig,
|
||||
TFDistilBertForQuestionAnswering,
|
||||
DistilBertForQuestionAnswering,
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"distilbert-base-uncased-distilled-squad": (
|
||||
DistilBertConfig,
|
||||
TFDistilBertForQuestionAnswering,
|
||||
DistilBertForQuestionAnswering,
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"ctrl": (
|
||||
CTRLConfig,
|
||||
TFCTRLLMHeadModel,
|
||||
CTRLLMHeadModel,
|
||||
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"albert": (
|
||||
AlbertConfig,
|
||||
TFAlbertForMaskedLM,
|
||||
AlbertForMaskedLM,
|
||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"t5": (
|
||||
T5Config,
|
||||
TFT5WithLMHeadModel,
|
||||
T5WithLMHeadModel,
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
}
|
||||
|
||||
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
|
||||
|
||||
def convert_pt_checkpoint_to_tf(
|
||||
model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
|
||||
):
|
||||
if model_type not in MODEL_CLASSES:
|
||||
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
|
||||
|
||||
@@ -116,17 +304,19 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
if pytorch_checkpoint_path in aws_model_maps:
|
||||
pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)
|
||||
pytorch_checkpoint_path = cached_path(
|
||||
aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models
|
||||
)
|
||||
# Load PyTorch checkpoint in tf2 model:
|
||||
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
|
||||
|
||||
if compare_with_pt_model:
|
||||
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
|
||||
|
||||
state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
|
||||
pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
|
||||
config=config,
|
||||
state_dict=state_dict)
|
||||
state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
|
||||
pt_model = pt_model_class.from_pretrained(
|
||||
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
pto = pt_model(**pt_model.dummy_inputs)
|
||||
@@ -139,11 +329,19 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
||||
|
||||
# Save pytorch-model
|
||||
print("Save TensorFlow model to {}".format(tf_dump_path))
|
||||
tf_model.save_weights(tf_dump_path, save_format='h5')
|
||||
tf_model.save_weights(tf_dump_path, save_format="h5")
|
||||
|
||||
|
||||
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
|
||||
compare_with_pt_model=False, use_cached_models=False, remove_cached_files=False, only_convert_finetuned_models=False):
|
||||
def convert_all_pt_checkpoints_to_tf(
|
||||
args_model_type,
|
||||
tf_dump_path,
|
||||
model_shortcut_names_or_path=None,
|
||||
config_shortcut_names_or_path=None,
|
||||
compare_with_pt_model=False,
|
||||
use_cached_models=False,
|
||||
remove_cached_files=False,
|
||||
only_convert_finetuned_models=False,
|
||||
):
|
||||
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
|
||||
|
||||
if args_model_type is None:
|
||||
@@ -156,7 +354,9 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
||||
print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
|
||||
print("=" * 100)
|
||||
if model_type not in MODEL_CLASSES:
|
||||
raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))
|
||||
raise ValueError(
|
||||
"Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys()))
|
||||
)
|
||||
|
||||
config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
||||
|
||||
@@ -166,9 +366,10 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
||||
config_shortcut_names_or_path = model_shortcut_names_or_path
|
||||
|
||||
for i, (model_shortcut_name, config_shortcut_name) in enumerate(
|
||||
zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1):
|
||||
zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
|
||||
):
|
||||
print("-" * 100)
|
||||
if '-squad' in model_shortcut_name or '-mrpc' in model_shortcut_name or '-mnli' in model_shortcut_name:
|
||||
if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
|
||||
if not only_convert_finetuned_models:
|
||||
print(" Skipping finetuned checkpoint {}".format(model_shortcut_name))
|
||||
continue
|
||||
@@ -176,7 +377,11 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
||||
elif only_convert_finetuned_models:
|
||||
print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name))
|
||||
continue
|
||||
print(" Converting checkpoint {}/{}: {} - model_type {}".format(i, len(aws_config_map), model_shortcut_name, model_type))
|
||||
print(
|
||||
" Converting checkpoint {}/{}: {} - model_type {}".format(
|
||||
i, len(aws_config_map), model_shortcut_name, model_type
|
||||
)
|
||||
)
|
||||
print("-" * 100)
|
||||
|
||||
if config_shortcut_name in aws_config_map:
|
||||
@@ -190,13 +395,15 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
||||
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
|
||||
|
||||
if os.path.isfile(model_shortcut_name):
|
||||
model_shortcut_name = 'converted_model'
|
||||
model_shortcut_name = "converted_model"
|
||||
|
||||
convert_pt_checkpoint_to_tf(model_type=model_type,
|
||||
pytorch_checkpoint_path=model_file,
|
||||
config_file=config_file,
|
||||
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
|
||||
compare_with_pt_model=compare_with_pt_model)
|
||||
convert_pt_checkpoint_to_tf(
|
||||
model_type=model_type,
|
||||
pytorch_checkpoint_path=model_file,
|
||||
config_file=config_file,
|
||||
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
|
||||
compare_with_pt_model=compare_with_pt_model,
|
||||
)
|
||||
if remove_cached_files:
|
||||
os.remove(config_file)
|
||||
os.remove(model_file)
|
||||
@@ -205,39 +412,47 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_dump_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output Tensorflow dump file.")
|
||||
parser.add_argument("--model_type",
|
||||
default = None,
|
||||
type = str,
|
||||
help = "Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(list(MODEL_CLASSES.keys())))
|
||||
parser.add_argument("--pytorch_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
help = "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
|
||||
"If not given, will download and convert all the checkpoints from AWS.")
|
||||
parser.add_argument("--config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
help = "The config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture. If not given and "
|
||||
"--pytorch_checkpoint_path is not given or is a shortcut name"
|
||||
"use the configuration associated to the shortcut name on the AWS")
|
||||
parser.add_argument("--compare_with_pt_model",
|
||||
action='store_true',
|
||||
help = "Compare Tensorflow and PyTorch model predictions.")
|
||||
parser.add_argument("--use_cached_models",
|
||||
action='store_true',
|
||||
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
|
||||
parser.add_argument("--remove_cached_files",
|
||||
action='store_true',
|
||||
help = "Remove pytorch models after conversion (save memory when converting in batches).")
|
||||
parser.add_argument("--only_convert_finetuned_models",
|
||||
action='store_true',
|
||||
help = "Only convert finetuned models.")
|
||||
parser.add_argument(
|
||||
"--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(
|
||||
list(MODEL_CLASSES.keys())
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
|
||||
"If not given, will download and convert all the checkpoints from AWS.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture. If not given and "
|
||||
"--pytorch_checkpoint_path is not given or is a shortcut name"
|
||||
"use the configuration associated to the shortcut name on the AWS",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cached_models",
|
||||
action="store_true",
|
||||
help="Use cached models if possible instead of updating to latest checkpoint versions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remove_cached_files",
|
||||
action="store_true",
|
||||
help="Remove pytorch models after conversion (save memory when converting in batches).",
|
||||
)
|
||||
parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
|
||||
args = parser.parse_args()
|
||||
|
||||
# if args.pytorch_checkpoint_path is not None:
|
||||
@@ -248,11 +463,15 @@ if __name__ == "__main__":
|
||||
# compare_with_pt_model=args.compare_with_pt_model,
|
||||
# use_cached_models=args.use_cached_models)
|
||||
# else:
|
||||
convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
|
||||
args.tf_dump_path,
|
||||
model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
|
||||
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
|
||||
compare_with_pt_model=args.compare_with_pt_model,
|
||||
use_cached_models=args.use_cached_models,
|
||||
remove_cached_files=args.remove_cached_files,
|
||||
only_convert_finetuned_models=args.only_convert_finetuned_models)
|
||||
convert_all_pt_checkpoints_to_tf(
|
||||
args.model_type.lower() if args.model_type is not None else None,
|
||||
args.tf_dump_path,
|
||||
model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
|
||||
if args.pytorch_checkpoint_path is not None
|
||||
else None,
|
||||
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
|
||||
compare_with_pt_model=args.compare_with_pt_model,
|
||||
use_cached_models=args.use_cached_models,
|
||||
remove_cached_files=args.remove_cached_files,
|
||||
only_convert_finetuned_models=args.only_convert_finetuned_models,
|
||||
)
|
||||
|
||||
@@ -30,20 +30,27 @@ if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||
|
||||
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
|
||||
from fairseq.modules import TransformerSentenceEncoderLayer
|
||||
from transformers.modeling_bert import (BertConfig, BertEncoder,
|
||||
BertIntermediate, BertLayer,
|
||||
BertModel, BertOutput,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput)
|
||||
from transformers.modeling_roberta import (RobertaEmbeddings,
|
||||
RobertaForMaskedLM,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaModel)
|
||||
from transformers.modeling_bert import (
|
||||
BertConfig,
|
||||
BertEncoder,
|
||||
BertIntermediate,
|
||||
BertLayer,
|
||||
BertModel,
|
||||
BertOutput,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput,
|
||||
)
|
||||
from transformers.modeling_roberta import (
|
||||
RobertaEmbeddings,
|
||||
RobertaForMaskedLM,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaModel,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SAMPLE_TEXT = 'Hello world! cécé herlolip'
|
||||
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
||||
|
||||
|
||||
def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):
|
||||
@@ -61,7 +68,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
intermediate_size=roberta.args.encoder_ffn_embed_dim,
|
||||
max_position_embeddings=514,
|
||||
type_vocab_size=1,
|
||||
layer_norm_eps=1e-5, # PyTorch default used in fairseq
|
||||
layer_norm_eps=1e-5, # PyTorch default used in fairseq
|
||||
)
|
||||
if classification_head:
|
||||
config.num_labels = roberta.args.num_classes
|
||||
@@ -74,7 +81,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
# Embeddings
|
||||
model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
|
||||
model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
|
||||
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them.
|
||||
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
|
||||
model.roberta.embeddings.token_type_embeddings.weight
|
||||
) # just zero them out b/c RoBERTa doesn't use them.
|
||||
model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
|
||||
model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias
|
||||
|
||||
@@ -85,11 +94,11 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
|
||||
### self attention
|
||||
self_attn: BertSelfAttention = layer.attention.self
|
||||
assert(
|
||||
roberta_layer.self_attn.k_proj.weight.data.shape == \
|
||||
roberta_layer.self_attn.q_proj.weight.data.shape == \
|
||||
roberta_layer.self_attn.v_proj.weight.data.shape == \
|
||||
torch.Size((config.hidden_size, config.hidden_size))
|
||||
assert (
|
||||
roberta_layer.self_attn.k_proj.weight.data.shape
|
||||
== roberta_layer.self_attn.q_proj.weight.data.shape
|
||||
== roberta_layer.self_attn.v_proj.weight.data.shape
|
||||
== torch.Size((config.hidden_size, config.hidden_size))
|
||||
)
|
||||
|
||||
self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight
|
||||
@@ -101,9 +110,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
|
||||
### self-attention output
|
||||
self_output: BertSelfOutput = layer.attention.output
|
||||
assert(
|
||||
self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
|
||||
)
|
||||
assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
|
||||
self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
|
||||
self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
|
||||
self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
|
||||
@@ -111,17 +118,13 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
|
||||
### intermediate
|
||||
intermediate: BertIntermediate = layer.intermediate
|
||||
assert(
|
||||
intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
|
||||
)
|
||||
assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
|
||||
intermediate.dense.weight = roberta_layer.fc1.weight
|
||||
intermediate.dense.bias = roberta_layer.fc1.bias
|
||||
|
||||
### output
|
||||
bert_output: BertOutput = layer.output
|
||||
assert(
|
||||
bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
|
||||
)
|
||||
assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
|
||||
bert_output.dense.weight = roberta_layer.fc2.weight
|
||||
bert_output.dense.bias = roberta_layer.fc2.bias
|
||||
bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
|
||||
@@ -129,10 +132,10 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
#### end of layer
|
||||
|
||||
if classification_head:
|
||||
model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight
|
||||
model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias
|
||||
model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight
|
||||
model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias
|
||||
model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight
|
||||
model.classifier.dense.bias = roberta.model.classification_heads["mnli"].dense.bias
|
||||
model.classifier.out_proj.weight = roberta.model.classification_heads["mnli"].out_proj.weight
|
||||
model.classifier.out_proj.bias = roberta.model.classification_heads["mnli"].out_proj.bias
|
||||
else:
|
||||
# LM Head
|
||||
model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
|
||||
@@ -143,21 +146,18 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
model.lm_head.bias = roberta.model.decoder.lm_head.bias
|
||||
|
||||
# Let's check that we get the same results.
|
||||
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
|
||||
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
|
||||
|
||||
our_output = model(input_ids)[0]
|
||||
if classification_head:
|
||||
their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
|
||||
their_output = roberta.model.classification_heads["mnli"](roberta.extract_features(input_ids))
|
||||
else:
|
||||
their_output = roberta.model(input_ids)[0]
|
||||
print(our_output.shape, their_output.shape)
|
||||
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
|
||||
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
|
||||
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
|
||||
success = torch.allclose(our_output, their_output, atol=1e-3)
|
||||
print(
|
||||
"Do both models output the same tensors?",
|
||||
"🔥" if success else "💩"
|
||||
)
|
||||
print("Do both models output the same tensors?", "🔥" if success else "💩")
|
||||
if not success:
|
||||
raise Exception("Something went wRoNg")
|
||||
|
||||
@@ -169,23 +169,16 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--roberta_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path the official PyTorch dump.")
|
||||
parser.add_argument("--pytorch_dump_folder_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
parser.add_argument("--classification_head",
|
||||
action = "store_true",
|
||||
help = "Whether to convert a final classification head.")
|
||||
parser.add_argument(
|
||||
"--roberta_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--classification_head", action="store_true", help="Whether to convert a final classification head."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_roberta_checkpoint_to_pytorch(
|
||||
args.roberta_checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.classification_head
|
||||
args.roberta_checkpoint_path, args.pytorch_dump_folder_path, args.classification_head
|
||||
)
|
||||
|
||||
|
||||
@@ -24,8 +24,10 @@ import torch
|
||||
from transformers import T5Config, T5Model, load_tf_weights_in_t5
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = T5Config.from_json_file(config_file)
|
||||
@@ -43,23 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained T5 model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained T5 model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
||||
args.config_file,
|
||||
args.pytorch_dump_path)
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
||||
|
||||
@@ -26,9 +26,8 @@ import torch
|
||||
import transformers.tokenization_transfo_xl as data_utils
|
||||
|
||||
from transformers import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers import (TransfoXLConfig, TransfoXLLMHeadModel,
|
||||
load_tf_weights_in_transfo_xl)
|
||||
from transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)
|
||||
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
|
||||
from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
@@ -36,32 +35,33 @@ else:
|
||||
import pickle
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# We do this to be able to load python 2 datasets pickles
|
||||
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
|
||||
data_utils.Vocab = data_utils.TransfoXLTokenizer
|
||||
data_utils.Corpus = data_utils.TransfoXLCorpus
|
||||
sys.modules['data_utils'] = data_utils
|
||||
sys.modules['vocabulary'] = data_utils
|
||||
sys.modules["data_utils"] = data_utils
|
||||
sys.modules["vocabulary"] = data_utils
|
||||
|
||||
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||
transfo_xl_config_file,
|
||||
pytorch_dump_folder_path,
|
||||
transfo_xl_dataset_file):
|
||||
|
||||
def convert_transfo_xl_checkpoint_to_pytorch(
|
||||
tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file
|
||||
):
|
||||
if transfo_xl_dataset_file:
|
||||
# Convert a pre-processed corpus (see original TensorFlow repo)
|
||||
with open(transfo_xl_dataset_file, "rb") as fp:
|
||||
corpus = pickle.load(fp, encoding="latin1")
|
||||
# Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
|
||||
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
|
||||
pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["pretrained_vocab_file"]
|
||||
print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
|
||||
corpus_vocab_dict = corpus.vocab.__dict__
|
||||
torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
|
||||
|
||||
corpus_dict_no_vocab = corpus.__dict__
|
||||
corpus_dict_no_vocab.pop('vocab', None)
|
||||
pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
|
||||
corpus_dict_no_vocab.pop("vocab", None)
|
||||
pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME
|
||||
print("Save dataset to {}".format(pytorch_dataset_dump_path))
|
||||
torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
|
||||
|
||||
@@ -92,26 +92,36 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the folder to store the PyTorch model or dataset/vocab.")
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = "",
|
||||
type = str,
|
||||
help = "An optional path to a TensorFlow checkpoint path to be converted.")
|
||||
parser.add_argument("--transfo_xl_config_file",
|
||||
default = "",
|
||||
type = str,
|
||||
help = "An optional config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--transfo_xl_dataset_file",
|
||||
default = "",
|
||||
type = str,
|
||||
help = "An optional dataset file to be converted in a vocabulary.")
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the folder to store the PyTorch model or dataset/vocab.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path",
|
||||
default="",
|
||||
type=str,
|
||||
help="An optional path to a TensorFlow checkpoint path to be converted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--transfo_xl_config_file",
|
||||
default="",
|
||||
type=str,
|
||||
help="An optional config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--transfo_xl_dataset_file",
|
||||
default="",
|
||||
type=str,
|
||||
help="An optional dataset file to be converted in a vocabulary.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
||||
args.transfo_xl_config_file,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.transfo_xl_dataset_file)
|
||||
convert_transfo_xl_checkpoint_to_pytorch(
|
||||
args.tf_checkpoint_path,
|
||||
args.transfo_xl_config_file,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.transfo_xl_dataset_file,
|
||||
)
|
||||
|
||||
@@ -27,32 +27,34 @@ from transformers import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
|
||||
# Load checkpoint
|
||||
chkpt = torch.load(xlm_checkpoint_path, map_location='cpu')
|
||||
chkpt = torch.load(xlm_checkpoint_path, map_location="cpu")
|
||||
|
||||
state_dict = chkpt['model']
|
||||
state_dict = chkpt["model"]
|
||||
|
||||
# We have the base model one level deeper than the original XLM repository
|
||||
two_levels_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if 'pred_layer' in k:
|
||||
if "pred_layer" in k:
|
||||
two_levels_state_dict[k] = v
|
||||
else:
|
||||
two_levels_state_dict['transformer.' + k] = v
|
||||
two_levels_state_dict["transformer." + k] = v
|
||||
|
||||
config = chkpt['params']
|
||||
config = chkpt["params"]
|
||||
config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))
|
||||
|
||||
vocab = chkpt['dico_word2id']
|
||||
vocab = dict((s + '</w>' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items())
|
||||
vocab = chkpt["dico_word2id"]
|
||||
vocab = dict((s + "</w>" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""), i) for s, i in vocab.items())
|
||||
|
||||
# Save pytorch-model
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||
pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["vocab_file"]
|
||||
|
||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||
torch.save(two_levels_state_dict, pytorch_weights_dump_path)
|
||||
@@ -69,15 +71,11 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--xlm_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path the official PyTorch dump.")
|
||||
parser.add_argument("--pytorch_dump_folder_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)
|
||||
|
||||
@@ -22,11 +22,15 @@ import os
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
from transformers import (CONFIG_NAME, WEIGHTS_NAME,
|
||||
XLNetConfig,
|
||||
XLNetLMHeadModel, XLNetForQuestionAnswering,
|
||||
XLNetForSequenceClassification,
|
||||
load_tf_weights_in_xlnet)
|
||||
from transformers import (
|
||||
CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
XLNetConfig,
|
||||
XLNetLMHeadModel,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetForSequenceClassification,
|
||||
load_tf_weights_in_xlnet,
|
||||
)
|
||||
|
||||
GLUE_TASKS_NUM_LABELS = {
|
||||
"cola": 2,
|
||||
@@ -41,9 +45,13 @@ GLUE_TASKS_NUM_LABELS = {
|
||||
}
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
|
||||
|
||||
def convert_xlnet_checkpoint_to_pytorch(
|
||||
tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None
|
||||
):
|
||||
# Initialise PyTorch model
|
||||
config = XLNetConfig.from_json_file(bert_config_file)
|
||||
|
||||
@@ -53,7 +61,7 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
|
||||
config.finetuning_task = finetuning_task
|
||||
config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
|
||||
model = XLNetForSequenceClassification(config)
|
||||
elif 'squad' in finetuning_task:
|
||||
elif "squad" in finetuning_task:
|
||||
config.finetuning_task = finetuning_task
|
||||
model = XLNetForQuestionAnswering(config)
|
||||
else:
|
||||
@@ -75,30 +83,33 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--xlnet_config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained XLNet model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_folder_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the folder to store the PyTorch model or dataset/vocab.")
|
||||
parser.add_argument("--finetuning_task",
|
||||
default = None,
|
||||
type = str,
|
||||
help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned")
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--xlnet_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained XLNet model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the folder to store the PyTorch model or dataset/vocab.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--finetuning_task",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Name of a task on which the XLNet TensorFloaw model was fine-tuned",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
||||
args.xlnet_config_file,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.finetuning_task)
|
||||
convert_xlnet_checkpoint_to_pytorch(
|
||||
args.tf_checkpoint_path, args.xlnet_config_file, args.pytorch_dump_folder_path, args.finetuning_task
|
||||
)
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures, SingleSentenceClassificationProcessor
|
||||
from .processors import (
|
||||
InputExample,
|
||||
InputFeatures,
|
||||
DataProcessor,
|
||||
SquadFeatures,
|
||||
SingleSentenceClassificationProcessor,
|
||||
)
|
||||
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
||||
from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor
|
||||
from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
|
||||
|
||||
from .metrics import is_sklearn_available
|
||||
|
||||
if is_sklearn_available():
|
||||
from .metrics import glue_compute_metrics, xnli_compute_metrics
|
||||
|
||||
@@ -23,20 +23,22 @@ logger = logging.getLogger(__name__)
|
||||
try:
|
||||
from scipy.stats import pearsonr, spearmanr
|
||||
from sklearn.metrics import matthews_corrcoef, f1_score
|
||||
|
||||
_has_sklearn = True
|
||||
except (AttributeError, ImportError) as e:
|
||||
logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
|
||||
_has_sklearn = False
|
||||
|
||||
|
||||
def is_sklearn_available():
|
||||
return _has_sklearn
|
||||
|
||||
|
||||
if _has_sklearn:
|
||||
|
||||
def simple_accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
def acc_and_f1(preds, labels):
|
||||
acc = simple_accuracy(preds, labels)
|
||||
f1 = f1_score(y_true=labels, y_pred=preds)
|
||||
@@ -46,7 +48,6 @@ if _has_sklearn:
|
||||
"acc_and_f1": (acc + f1) / 2,
|
||||
}
|
||||
|
||||
|
||||
def pearson_and_spearman(preds, labels):
|
||||
pearson_corr = pearsonr(preds, labels)[0]
|
||||
spearman_corr = spearmanr(preds, labels)[0]
|
||||
@@ -56,7 +57,6 @@ if _has_sklearn:
|
||||
"corr": (pearson_corr + spearman_corr) / 2,
|
||||
}
|
||||
|
||||
|
||||
def glue_compute_metrics(task_name, preds, labels):
|
||||
assert len(preds) == len(labels)
|
||||
if task_name == "cola":
|
||||
@@ -82,7 +82,6 @@ if _has_sklearn:
|
||||
else:
|
||||
raise KeyError(task_name)
|
||||
|
||||
|
||||
def xnli_compute_metrics(task_name, preds, labels):
|
||||
assert len(preds) == len(labels)
|
||||
if task_name == "xnli":
|
||||
|
||||
@@ -24,19 +24,21 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
|
||||
def remove_articles(text):
|
||||
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
|
||||
return re.sub(regex, ' ', text)
|
||||
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
|
||||
return re.sub(regex, " ", text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return ' '.join(text.split())
|
||||
return " ".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return ''.join(ch for ch in text if ch not in exclude)
|
||||
return "".join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
@@ -75,14 +77,14 @@ def get_raw_scores(examples, preds):
|
||||
|
||||
for example in examples:
|
||||
qas_id = example.qas_id
|
||||
gold_answers = [answer['text'] for answer in example.answers if normalize_answer(answer['text'])]
|
||||
gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]
|
||||
|
||||
if not gold_answers:
|
||||
# For unanswerable questions, only correct answer is empty string
|
||||
gold_answers = ['']
|
||||
gold_answers = [""]
|
||||
|
||||
if qas_id not in preds:
|
||||
print('Missing prediction for %s' % qas_id)
|
||||
print("Missing prediction for %s" % qas_id)
|
||||
continue
|
||||
|
||||
prediction = preds[qas_id]
|
||||
@@ -106,23 +108,27 @@ def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
|
||||
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
|
||||
if not qid_list:
|
||||
total = len(exact_scores)
|
||||
return collections.OrderedDict([
|
||||
('exact', 100.0 * sum(exact_scores.values()) / total),
|
||||
('f1', 100.0 * sum(f1_scores.values()) / total),
|
||||
('total', total),
|
||||
])
|
||||
return collections.OrderedDict(
|
||||
[
|
||||
("exact", 100.0 * sum(exact_scores.values()) / total),
|
||||
("f1", 100.0 * sum(f1_scores.values()) / total),
|
||||
("total", total),
|
||||
]
|
||||
)
|
||||
else:
|
||||
total = len(qid_list)
|
||||
return collections.OrderedDict([
|
||||
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
|
||||
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
|
||||
('total', total),
|
||||
])
|
||||
return collections.OrderedDict(
|
||||
[
|
||||
("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
|
||||
("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
|
||||
("total", total),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def merge_eval(main_eval, new_eval, prefix):
|
||||
for k in new_eval:
|
||||
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
|
||||
main_eval["%s_%s" % (prefix, k)] = new_eval[k]
|
||||
|
||||
|
||||
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
|
||||
@@ -160,16 +166,14 @@ def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
|
||||
|
||||
|
||||
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
||||
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(
|
||||
preds, exact_raw, na_probs, qid_to_has_ans)
|
||||
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(
|
||||
preds, f1_raw, na_probs, qid_to_has_ans)
|
||||
main_eval['best_exact'] = best_exact
|
||||
main_eval['best_exact_thresh'] = exact_thresh
|
||||
main_eval['best_f1'] = best_f1
|
||||
main_eval['best_f1_thresh'] = f1_thresh
|
||||
main_eval['has_ans_exact'] = has_ans_exact
|
||||
main_eval['has_ans_f1'] = has_ans_f1
|
||||
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
|
||||
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
|
||||
main_eval["best_exact"] = best_exact
|
||||
main_eval["best_exact_thresh"] = exact_thresh
|
||||
main_eval["best_f1"] = best_f1
|
||||
main_eval["best_f1_thresh"] = f1_thresh
|
||||
main_eval["has_ans_exact"] = has_ans_exact
|
||||
main_eval["has_ans_f1"] = has_ans_f1
|
||||
|
||||
|
||||
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
||||
@@ -199,10 +203,10 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_h
|
||||
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
|
||||
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
|
||||
|
||||
main_eval['best_exact'] = best_exact
|
||||
main_eval['best_exact_thresh'] = exact_thresh
|
||||
main_eval['best_f1'] = best_f1
|
||||
main_eval['best_f1_thresh'] = f1_thresh
|
||||
main_eval["best_exact"] = best_exact
|
||||
main_eval["best_exact_thresh"] = exact_thresh
|
||||
main_eval["best_f1"] = best_f1
|
||||
main_eval["best_f1_thresh"] = f1_thresh
|
||||
|
||||
|
||||
def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
|
||||
@@ -215,18 +219,20 @@ def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_
|
||||
|
||||
exact, f1 = get_raw_scores(examples, preds)
|
||||
|
||||
exact_threshold = apply_no_ans_threshold(exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
|
||||
exact_threshold = apply_no_ans_threshold(
|
||||
exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
|
||||
)
|
||||
f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
|
||||
|
||||
evaluation = make_eval_dict(exact_threshold, f1_threshold)
|
||||
|
||||
if has_answer_qids:
|
||||
has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
|
||||
merge_eval(evaluation, has_ans_eval, 'HasAns')
|
||||
merge_eval(evaluation, has_ans_eval, "HasAns")
|
||||
|
||||
if no_answer_qids:
|
||||
no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
|
||||
merge_eval(evaluation, no_ans_eval, 'NoAns')
|
||||
merge_eval(evaluation, no_ans_eval, "NoAns")
|
||||
|
||||
if no_answer_probs:
|
||||
find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
|
||||
@@ -284,8 +290,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||
start_position = tok_text.find(pred_text)
|
||||
if start_position == -1:
|
||||
if verbose_logging:
|
||||
logger.info(
|
||||
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
|
||||
logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
|
||||
return orig_text
|
||||
end_position = start_position + len(pred_text) - 1
|
||||
|
||||
@@ -294,8 +299,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||
|
||||
if len(orig_ns_text) != len(tok_ns_text):
|
||||
if verbose_logging:
|
||||
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
|
||||
orig_ns_text, tok_ns_text)
|
||||
logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text)
|
||||
return orig_text
|
||||
|
||||
# We then project the characters in `pred_text` back to `orig_text` using
|
||||
@@ -326,7 +330,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||
logger.info("Couldn't map end position")
|
||||
return orig_text
|
||||
|
||||
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
|
||||
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
|
||||
return output_text
|
||||
|
||||
|
||||
@@ -393,8 +397,8 @@ def compute_predictions_logits(
|
||||
unique_id_to_result[result.unique_id] = result
|
||||
|
||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"PrelimPrediction",
|
||||
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
|
||||
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
|
||||
)
|
||||
|
||||
all_predictions = collections.OrderedDict()
|
||||
all_nbest_json = collections.OrderedDict()
|
||||
@@ -447,7 +451,9 @@ def compute_predictions_logits(
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
start_logit=result.start_logits[start_index],
|
||||
end_logit=result.end_logits[end_index]))
|
||||
end_logit=result.end_logits[end_index],
|
||||
)
|
||||
)
|
||||
if version_2_with_negative:
|
||||
prelim_predictions.append(
|
||||
_PrelimPrediction(
|
||||
@@ -455,14 +461,14 @@ def compute_predictions_logits(
|
||||
start_index=0,
|
||||
end_index=0,
|
||||
start_logit=null_start_logit,
|
||||
end_logit=null_end_logit))
|
||||
prelim_predictions = sorted(
|
||||
prelim_predictions,
|
||||
key=lambda x: (x.start_logit + x.end_logit),
|
||||
reverse=True)
|
||||
end_logit=null_end_logit,
|
||||
)
|
||||
)
|
||||
prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
|
||||
|
||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"NbestPrediction", ["text", "start_logit", "end_logit"])
|
||||
"NbestPrediction", ["text", "start_logit", "end_logit"]
|
||||
)
|
||||
|
||||
seen_predictions = {}
|
||||
nbest = []
|
||||
@@ -471,10 +477,10 @@ def compute_predictions_logits(
|
||||
break
|
||||
feature = features[pred.feature_index]
|
||||
if pred.start_index > 0: # this is a non-null prediction
|
||||
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
|
||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
|
||||
|
||||
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
||||
|
||||
@@ -498,31 +504,21 @@ def compute_predictions_logits(
|
||||
final_text = ""
|
||||
seen_predictions[final_text] = True
|
||||
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text=final_text,
|
||||
start_logit=pred.start_logit,
|
||||
end_logit=pred.end_logit))
|
||||
nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
|
||||
# if we didn't include the empty option in the n-best, include it
|
||||
if version_2_with_negative:
|
||||
if "" not in seen_predictions:
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text="",
|
||||
start_logit=null_start_logit,
|
||||
end_logit=null_end_logit))
|
||||
nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
|
||||
|
||||
# In very rare edge cases we could only have single null prediction.
|
||||
# So we just create a nonce prediction in this case to avoid failure.
|
||||
if len(nbest) == 1:
|
||||
nbest.insert(0,
|
||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
|
||||
# In very rare edge cases we could have no valid predictions. So we
|
||||
# just create a nonce prediction in this case to avoid failure.
|
||||
if not nbest:
|
||||
nbest.append(
|
||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
|
||||
assert len(nbest) >= 1
|
||||
|
||||
@@ -551,8 +547,7 @@ def compute_predictions_logits(
|
||||
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
||||
else:
|
||||
# predict "" iff the null score - the score of best non-null > threshold
|
||||
score_diff = score_null - best_non_null_entry.start_logit - (
|
||||
best_non_null_entry.end_logit)
|
||||
score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
|
||||
scores_diff_json[example.qas_id] = score_diff
|
||||
if score_diff > null_score_diff_threshold:
|
||||
all_predictions[example.qas_id] = ""
|
||||
@@ -586,7 +581,7 @@ def compute_predictions_log_probs(
|
||||
end_n_top,
|
||||
version_2_with_negative,
|
||||
tokenizer,
|
||||
verbose_logging
|
||||
verbose_logging,
|
||||
):
|
||||
""" XLNet write prediction logic (more complex than Bert's).
|
||||
Write final predictions to the json file and log-odds of null if needed.
|
||||
@@ -594,12 +589,12 @@ def compute_predictions_log_probs(
|
||||
Requires utils_squad_evaluate.py
|
||||
"""
|
||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"PrelimPrediction",
|
||||
["feature_index", "start_index", "end_index",
|
||||
"start_log_prob", "end_log_prob"])
|
||||
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
|
||||
)
|
||||
|
||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
|
||||
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
|
||||
)
|
||||
|
||||
logger.info("Writing predictions to: %s", output_prediction_file)
|
||||
# logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||
@@ -663,12 +658,13 @@ def compute_predictions_log_probs(
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
start_log_prob=start_log_prob,
|
||||
end_log_prob=end_log_prob))
|
||||
end_log_prob=end_log_prob,
|
||||
)
|
||||
)
|
||||
|
||||
prelim_predictions = sorted(
|
||||
prelim_predictions,
|
||||
key=lambda x: (x.start_log_prob + x.end_log_prob),
|
||||
reverse=True)
|
||||
prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
|
||||
)
|
||||
|
||||
seen_predictions = {}
|
||||
nbest = []
|
||||
@@ -688,10 +684,10 @@ def compute_predictions_log_probs(
|
||||
# final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
||||
|
||||
# Previously used Bert untokenizer
|
||||
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
|
||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
|
||||
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
||||
|
||||
# Clean whitespace
|
||||
@@ -704,8 +700,7 @@ def compute_predictions_log_probs(
|
||||
else:
|
||||
do_lower_case = tokenizer.do_lowercase_and_remove_accent
|
||||
|
||||
final_text = get_final_text(tok_text, orig_text, do_lower_case,
|
||||
verbose_logging)
|
||||
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
|
||||
|
||||
if final_text in seen_predictions:
|
||||
continue
|
||||
@@ -713,17 +708,13 @@ def compute_predictions_log_probs(
|
||||
seen_predictions[final_text] = True
|
||||
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text=final_text,
|
||||
start_log_prob=pred.start_log_prob,
|
||||
end_log_prob=pred.end_log_prob))
|
||||
_NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
|
||||
)
|
||||
|
||||
# In very rare edge cases we could have no valid predictions. So we
|
||||
# just create a nonce prediction in this case to avoid failure.
|
||||
if not nbest:
|
||||
nbest.append(
|
||||
_NbestPrediction(text="", start_log_prob=-1e6,
|
||||
end_log_prob=-1e6))
|
||||
nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
|
||||
|
||||
total_scores = []
|
||||
best_non_null_entry = None
|
||||
|
||||
@@ -27,15 +27,18 @@ if is_tf_available():
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def glue_convert_examples_to_features(examples, tokenizer,
|
||||
max_length=512,
|
||||
task=None,
|
||||
label_list=None,
|
||||
output_mode=None,
|
||||
pad_on_left=False,
|
||||
pad_token=0,
|
||||
pad_token_segment_id=0,
|
||||
mask_padding_with_zero=True):
|
||||
def glue_convert_examples_to_features(
|
||||
examples,
|
||||
tokenizer,
|
||||
max_length=512,
|
||||
task=None,
|
||||
label_list=None,
|
||||
output_mode=None,
|
||||
pad_on_left=False,
|
||||
pad_token=0,
|
||||
pad_token_segment_id=0,
|
||||
mask_padding_with_zero=True,
|
||||
):
|
||||
"""
|
||||
Loads a data file into a list of ``InputFeatures``
|
||||
|
||||
@@ -82,12 +85,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
|
||||
example = processor.get_example_from_tensor_dict(example)
|
||||
example = processor.tfds_map(example)
|
||||
|
||||
inputs = tokenizer.encode_plus(
|
||||
example.text_a,
|
||||
example.text_b,
|
||||
add_special_tokens=True,
|
||||
max_length=max_length,
|
||||
)
|
||||
inputs = tokenizer.encode_plus(example.text_a, example.text_b, add_special_tokens=True, max_length=max_length,)
|
||||
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
@@ -106,8 +104,12 @@ def glue_convert_examples_to_features(examples, tokenizer,
|
||||
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
|
||||
|
||||
assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
|
||||
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
|
||||
assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length)
|
||||
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
|
||||
len(attention_mask), max_length
|
||||
)
|
||||
assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
|
||||
len(token_type_ids), max_length
|
||||
)
|
||||
|
||||
if output_mode == "classification":
|
||||
label = label_map[example.label]
|
||||
@@ -125,28 +127,36 @@ def glue_convert_examples_to_features(examples, tokenizer,
|
||||
logger.info("label: %s (id = %d)" % (example.label, label))
|
||||
|
||||
features.append(
|
||||
InputFeatures(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
label=label))
|
||||
InputFeatures(
|
||||
input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, label=label
|
||||
)
|
||||
)
|
||||
|
||||
if is_tf_available() and is_tf_dataset:
|
||||
|
||||
def gen():
|
||||
for ex in features:
|
||||
yield ({'input_ids': ex.input_ids,
|
||||
'attention_mask': ex.attention_mask,
|
||||
'token_type_ids': ex.token_type_ids},
|
||||
ex.label)
|
||||
yield (
|
||||
{
|
||||
"input_ids": ex.input_ids,
|
||||
"attention_mask": ex.attention_mask,
|
||||
"token_type_ids": ex.token_type_ids,
|
||||
},
|
||||
ex.label,
|
||||
)
|
||||
|
||||
return tf.data.Dataset.from_generator(gen,
|
||||
({'input_ids': tf.int32,
|
||||
'attention_mask': tf.int32,
|
||||
'token_type_ids': tf.int32},
|
||||
tf.int64),
|
||||
({'input_ids': tf.TensorShape([None]),
|
||||
'attention_mask': tf.TensorShape([None]),
|
||||
'token_type_ids': tf.TensorShape([None])},
|
||||
tf.TensorShape([])))
|
||||
return tf.data.Dataset.from_generator(
|
||||
gen,
|
||||
({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
|
||||
(
|
||||
{
|
||||
"input_ids": tf.TensorShape([None]),
|
||||
"attention_mask": tf.TensorShape([None]),
|
||||
"token_type_ids": tf.TensorShape([None]),
|
||||
},
|
||||
tf.TensorShape([]),
|
||||
),
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
@@ -156,21 +166,21 @@ class MrpcProcessor(DataProcessor):
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
||||
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
return InputExample(
|
||||
tensor_dict["idx"].numpy(),
|
||||
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
||||
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
||||
str(tensor_dict["label"].numpy()),
|
||||
)
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
@@ -186,8 +196,7 @@ class MrpcProcessor(DataProcessor):
|
||||
text_a = line[3]
|
||||
text_b = line[4]
|
||||
label = line[0]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
@@ -196,21 +205,20 @@ class MnliProcessor(DataProcessor):
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['premise'].numpy().decode('utf-8'),
|
||||
tensor_dict['hypothesis'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
return InputExample(
|
||||
tensor_dict["idx"].numpy(),
|
||||
tensor_dict["premise"].numpy().decode("utf-8"),
|
||||
tensor_dict["hypothesis"].numpy().decode("utf-8"),
|
||||
str(tensor_dict["label"].numpy()),
|
||||
)
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
|
||||
"dev_matched")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
@@ -226,8 +234,7 @@ class MnliProcessor(DataProcessor):
|
||||
text_a = line[8]
|
||||
text_b = line[9]
|
||||
label = line[-1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
@@ -236,9 +243,7 @@ class MnliMismatchedProcessor(MnliProcessor):
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
|
||||
"dev_matched")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")
|
||||
|
||||
|
||||
class ColaProcessor(DataProcessor):
|
||||
@@ -246,20 +251,20 @@ class ColaProcessor(DataProcessor):
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['sentence'].numpy().decode('utf-8'),
|
||||
None,
|
||||
str(tensor_dict['label'].numpy()))
|
||||
return InputExample(
|
||||
tensor_dict["idx"].numpy(),
|
||||
tensor_dict["sentence"].numpy().decode("utf-8"),
|
||||
None,
|
||||
str(tensor_dict["label"].numpy()),
|
||||
)
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
@@ -272,8 +277,7 @@ class ColaProcessor(DataProcessor):
|
||||
guid = "%s-%s" % (set_type, i)
|
||||
text_a = line[3]
|
||||
label = line[1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
@@ -282,20 +286,20 @@ class Sst2Processor(DataProcessor):
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['sentence'].numpy().decode('utf-8'),
|
||||
None,
|
||||
str(tensor_dict['label'].numpy()))
|
||||
return InputExample(
|
||||
tensor_dict["idx"].numpy(),
|
||||
tensor_dict["sentence"].numpy().decode("utf-8"),
|
||||
None,
|
||||
str(tensor_dict["label"].numpy()),
|
||||
)
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
@@ -310,8 +314,7 @@ class Sst2Processor(DataProcessor):
|
||||
guid = "%s-%s" % (set_type, i)
|
||||
text_a = line[0]
|
||||
label = line[1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
@@ -320,20 +323,20 @@ class StsbProcessor(DataProcessor):
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
||||
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
return InputExample(
|
||||
tensor_dict["idx"].numpy(),
|
||||
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
||||
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
||||
str(tensor_dict["label"].numpy()),
|
||||
)
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
@@ -349,8 +352,7 @@ class StsbProcessor(DataProcessor):
|
||||
text_a = line[7]
|
||||
text_b = line[8]
|
||||
label = line[-1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
@@ -359,20 +361,20 @@ class QqpProcessor(DataProcessor):
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['question1'].numpy().decode('utf-8'),
|
||||
tensor_dict['question2'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
return InputExample(
|
||||
tensor_dict["idx"].numpy(),
|
||||
tensor_dict["question1"].numpy().decode("utf-8"),
|
||||
tensor_dict["question2"].numpy().decode("utf-8"),
|
||||
str(tensor_dict["label"].numpy()),
|
||||
)
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
@@ -391,8 +393,7 @@ class QqpProcessor(DataProcessor):
|
||||
label = line[5]
|
||||
except IndexError:
|
||||
continue
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
@@ -401,21 +402,20 @@ class QnliProcessor(DataProcessor):
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['question'].numpy().decode('utf-8'),
|
||||
tensor_dict['sentence'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
return InputExample(
|
||||
tensor_dict["idx"].numpy(),
|
||||
tensor_dict["question"].numpy().decode("utf-8"),
|
||||
tensor_dict["sentence"].numpy().decode("utf-8"),
|
||||
str(tensor_dict["label"].numpy()),
|
||||
)
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")),
|
||||
"dev_matched")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
@@ -431,8 +431,7 @@ class QnliProcessor(DataProcessor):
|
||||
text_a = line[1]
|
||||
text_b = line[2]
|
||||
label = line[-1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
@@ -441,20 +440,20 @@ class RteProcessor(DataProcessor):
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
||||
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
return InputExample(
|
||||
tensor_dict["idx"].numpy(),
|
||||
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
||||
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
||||
str(tensor_dict["label"].numpy()),
|
||||
)
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
@@ -470,8 +469,7 @@ class RteProcessor(DataProcessor):
|
||||
text_a = line[1]
|
||||
text_b = line[2]
|
||||
label = line[-1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
@@ -480,20 +478,20 @@ class WnliProcessor(DataProcessor):
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
||||
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
return InputExample(
|
||||
tensor_dict["idx"].numpy(),
|
||||
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
||||
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
||||
str(tensor_dict["label"].numpy()),
|
||||
)
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
@@ -509,10 +507,10 @@ class WnliProcessor(DataProcessor):
|
||||
text_a = line[1]
|
||||
text_b = line[2]
|
||||
label = line[-1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
glue_tasks_num_labels = {
|
||||
"cola": 2,
|
||||
"mnli": 3,
|
||||
|
||||
@@ -82,8 +82,8 @@ def _is_whitespace(c):
|
||||
return True
|
||||
return False
|
||||
|
||||
def squad_convert_example_to_features(example, max_seq_length,
|
||||
doc_stride, max_query_length, is_training):
|
||||
|
||||
def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training):
|
||||
features = []
|
||||
if is_training and not example.is_impossible:
|
||||
# Get start and end position
|
||||
@@ -91,7 +91,7 @@ def squad_convert_example_to_features(example, max_seq_length,
|
||||
end_position = example.end_position
|
||||
|
||||
# If the answer cannot be found in the text, then skip this example.
|
||||
actual_text = " ".join(example.doc_tokens[start_position:(end_position + 1)])
|
||||
actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
|
||||
cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
|
||||
if actual_text.find(cleaned_answer_text) == -1:
|
||||
logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
|
||||
@@ -121,8 +121,11 @@ def squad_convert_example_to_features(example, max_seq_length,
|
||||
spans = []
|
||||
|
||||
truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
|
||||
sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + 1 \
|
||||
if 'roberta' in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence
|
||||
sequence_added_tokens = (
|
||||
tokenizer.max_len - tokenizer.max_len_single_sentence + 1
|
||||
if "roberta" in str(type(tokenizer))
|
||||
else tokenizer.max_len - tokenizer.max_len_single_sentence
|
||||
)
|
||||
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
|
||||
|
||||
span_doc_tokens = all_doc_tokens
|
||||
@@ -135,16 +138,18 @@ def squad_convert_example_to_features(example, max_seq_length,
|
||||
return_overflowing_tokens=True,
|
||||
pad_to_max_length=True,
|
||||
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
|
||||
truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
|
||||
truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
|
||||
)
|
||||
|
||||
paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride,
|
||||
max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
|
||||
paragraph_len = min(
|
||||
len(all_doc_tokens) - len(spans) * doc_stride,
|
||||
max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
|
||||
)
|
||||
|
||||
if tokenizer.pad_token_id in encoded_dict['input_ids']:
|
||||
non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
|
||||
if tokenizer.pad_token_id in encoded_dict["input_ids"]:
|
||||
non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
|
||||
else:
|
||||
non_padded_ids = encoded_dict['input_ids']
|
||||
non_padded_ids = encoded_dict["input_ids"]
|
||||
|
||||
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
|
||||
|
||||
@@ -170,17 +175,20 @@ def squad_convert_example_to_features(example, max_seq_length,
|
||||
for doc_span_index in range(len(spans)):
|
||||
for j in range(spans[doc_span_index]["paragraph_len"]):
|
||||
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
|
||||
index = j if tokenizer.padding_side == "left" else spans[doc_span_index][
|
||||
"truncated_query_with_special_tokens_length"] + j
|
||||
index = (
|
||||
j
|
||||
if tokenizer.padding_side == "left"
|
||||
else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
|
||||
)
|
||||
spans[doc_span_index]["token_is_max_context"][index] = is_max_context
|
||||
|
||||
for span in spans:
|
||||
# Identify the position of the CLS token
|
||||
cls_index = span['input_ids'].index(tokenizer.cls_token_id)
|
||||
cls_index = span["input_ids"].index(tokenizer.cls_token_id)
|
||||
|
||||
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
|
||||
# Original TF implem also keep the classification token (set to 0) (not sure why...)
|
||||
p_mask = np.array(span['token_type_ids'])
|
||||
p_mask = np.array(span["token_type_ids"])
|
||||
|
||||
p_mask = np.minimum(p_mask, 1)
|
||||
|
||||
@@ -219,31 +227,34 @@ def squad_convert_example_to_features(example, max_seq_length,
|
||||
start_position = tok_start_position - doc_start + doc_offset
|
||||
end_position = tok_end_position - doc_start + doc_offset
|
||||
|
||||
features.append(SquadFeatures(
|
||||
span['input_ids'],
|
||||
span['attention_mask'],
|
||||
span['token_type_ids'],
|
||||
cls_index,
|
||||
p_mask.tolist(),
|
||||
example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing.
|
||||
unique_id=0,
|
||||
paragraph_len=span['paragraph_len'],
|
||||
token_is_max_context=span["token_is_max_context"],
|
||||
tokens=span["tokens"],
|
||||
token_to_orig_map=span["token_to_orig_map"],
|
||||
|
||||
start_position=start_position,
|
||||
end_position=end_position
|
||||
))
|
||||
features.append(
|
||||
SquadFeatures(
|
||||
span["input_ids"],
|
||||
span["attention_mask"],
|
||||
span["token_type_ids"],
|
||||
cls_index,
|
||||
p_mask.tolist(),
|
||||
example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing.
|
||||
unique_id=0,
|
||||
paragraph_len=span["paragraph_len"],
|
||||
token_is_max_context=span["token_is_max_context"],
|
||||
tokens=span["tokens"],
|
||||
token_to_orig_map=span["token_to_orig_map"],
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
)
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
def squad_convert_example_to_features_init(tokenizer_for_convert):
|
||||
global tokenizer
|
||||
tokenizer = tokenizer_for_convert
|
||||
|
||||
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
doc_stride, max_query_length, is_training,
|
||||
return_dataset=False, threads=1):
|
||||
|
||||
def squad_convert_examples_to_features(
|
||||
examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False, threads=1
|
||||
):
|
||||
"""
|
||||
Converts a list of examples into a list of features that can be directly given as input to a model.
|
||||
It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
|
||||
@@ -283,13 +294,24 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
features = []
|
||||
threads = min(threads, cpu_count())
|
||||
with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
|
||||
annotate_ = partial(squad_convert_example_to_features, max_seq_length=max_seq_length,
|
||||
doc_stride=doc_stride, max_query_length=max_query_length, is_training=is_training)
|
||||
features = list(tqdm(p.imap(annotate_, examples, chunksize=32), total=len(examples), desc='convert squad examples to features'))
|
||||
annotate_ = partial(
|
||||
squad_convert_example_to_features,
|
||||
max_seq_length=max_seq_length,
|
||||
doc_stride=doc_stride,
|
||||
max_query_length=max_query_length,
|
||||
is_training=is_training,
|
||||
)
|
||||
features = list(
|
||||
tqdm(
|
||||
p.imap(annotate_, examples, chunksize=32),
|
||||
total=len(examples),
|
||||
desc="convert squad examples to features",
|
||||
)
|
||||
)
|
||||
new_features = []
|
||||
unique_id = 1000000000
|
||||
example_index = 0
|
||||
for example_features in tqdm(features, total=len(features), desc='add example index and unique id'):
|
||||
for example_features in tqdm(features, total=len(features), desc="add example index and unique id"):
|
||||
if not example_features:
|
||||
continue
|
||||
for example_feature in example_features:
|
||||
@@ -300,7 +322,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
example_index += 1
|
||||
features = new_features
|
||||
del new_features
|
||||
if return_dataset == 'pt':
|
||||
if return_dataset == "pt":
|
||||
if not is_torch_available():
|
||||
raise ImportError("Pytorch must be installed to return a pytorch dataset.")
|
||||
|
||||
@@ -341,12 +363,13 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
"input_ids": ex.input_ids,
|
||||
"attention_mask": ex.attention_mask,
|
||||
"token_type_ids": ex.token_type_ids,
|
||||
}, {
|
||||
},
|
||||
{
|
||||
"start_position": ex.start_position,
|
||||
"end_position": ex.end_position,
|
||||
"cls_index": ex.cls_index,
|
||||
"p_mask": ex.p_mask,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return tf.data.Dataset.from_generator(
|
||||
|
||||
@@ -24,6 +24,7 @@ from ...file_utils import is_tf_available, is_torch_available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InputExample(object):
|
||||
"""
|
||||
A single training/test example for simple sequence classification.
|
||||
@@ -37,6 +38,7 @@ class InputExample(object):
|
||||
label: (Optional) string. The label of the example. This should be
|
||||
specified for train and dev examples, but not for test examples.
|
||||
"""
|
||||
|
||||
def __init__(self, guid, text_a, text_b=None, label=None):
|
||||
self.guid = guid
|
||||
self.text_a = text_a
|
||||
@@ -99,14 +101,15 @@ class DataProcessor(object):
|
||||
lines = []
|
||||
for line in reader:
|
||||
if sys.version_info[0] == 2:
|
||||
line = list(unicode(cell, 'utf-8') for cell in line)
|
||||
line = list(unicode(cell, "utf-8") for cell in line)
|
||||
lines.append(line)
|
||||
return lines
|
||||
|
||||
|
||||
class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
""" Generic processor for a single sentence classification data set."""
|
||||
def __init__(self, labels=None, examples=None, mode='classification', verbose=False):
|
||||
|
||||
def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
|
||||
self.labels = [] if labels is None else labels
|
||||
self.examples = [] if examples is None else examples
|
||||
self.mode = mode
|
||||
@@ -117,22 +120,24 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, slice):
|
||||
return SingleSentenceClassificationProcessor(labels=self.labels,
|
||||
examples=self.examples[idx])
|
||||
return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
|
||||
return self.examples[idx]
|
||||
|
||||
@classmethod
|
||||
def create_from_csv(cls, file_name, split_name='', column_label=0, column_text=1,
|
||||
column_id=None, skip_first_row=False, **kwargs):
|
||||
def create_from_csv(
|
||||
cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
|
||||
):
|
||||
processor = cls(**kwargs)
|
||||
processor.add_examples_from_csv(file_name,
|
||||
split_name=split_name,
|
||||
column_label=column_label,
|
||||
column_text=column_text,
|
||||
column_id=column_id,
|
||||
skip_first_row=skip_first_row,
|
||||
overwrite_labels=True,
|
||||
overwrite_examples=True)
|
||||
processor.add_examples_from_csv(
|
||||
file_name,
|
||||
split_name=split_name,
|
||||
column_label=column_label,
|
||||
column_text=column_text,
|
||||
column_id=column_id,
|
||||
skip_first_row=skip_first_row,
|
||||
overwrite_labels=True,
|
||||
overwrite_examples=True,
|
||||
)
|
||||
return processor
|
||||
|
||||
@classmethod
|
||||
@@ -141,8 +146,17 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
processor.add_examples(texts_or_text_and_labels, labels=labels)
|
||||
return processor
|
||||
|
||||
def add_examples_from_csv(self, file_name, split_name='', column_label=0, column_text=1, column_id=None,
|
||||
skip_first_row=False, overwrite_labels=False, overwrite_examples=False):
|
||||
def add_examples_from_csv(
|
||||
self,
|
||||
file_name,
|
||||
split_name="",
|
||||
column_label=0,
|
||||
column_text=1,
|
||||
column_id=None,
|
||||
skip_first_row=False,
|
||||
overwrite_labels=False,
|
||||
overwrite_examples=False,
|
||||
):
|
||||
lines = self._read_tsv(file_name)
|
||||
if skip_first_row:
|
||||
lines = lines[1:]
|
||||
@@ -158,10 +172,13 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
|
||||
ids.append(guid)
|
||||
|
||||
return self.add_examples(texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples)
|
||||
return self.add_examples(
|
||||
texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
|
||||
)
|
||||
|
||||
def add_examples(self, texts_or_text_and_labels, labels=None, ids=None,
|
||||
overwrite_labels=False, overwrite_examples=False):
|
||||
def add_examples(
|
||||
self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
|
||||
):
|
||||
assert labels is None or len(texts_or_text_and_labels) == len(labels)
|
||||
assert ids is None or len(texts_or_text_and_labels) == len(ids)
|
||||
if ids is None:
|
||||
@@ -192,13 +209,15 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
|
||||
return self.examples
|
||||
|
||||
def get_features(self,
|
||||
tokenizer,
|
||||
max_length=None,
|
||||
pad_on_left=False,
|
||||
pad_token=0,
|
||||
mask_padding_with_zero=True,
|
||||
return_tensors=None):
|
||||
def get_features(
|
||||
self,
|
||||
tokenizer,
|
||||
max_length=None,
|
||||
pad_on_left=False,
|
||||
pad_token=0,
|
||||
mask_padding_with_zero=True,
|
||||
return_tensors=None,
|
||||
):
|
||||
"""
|
||||
Convert examples in a list of ``InputFeatures``
|
||||
|
||||
@@ -231,9 +250,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
logger.info("Tokenizing example %d", ex_index)
|
||||
|
||||
input_ids = tokenizer.encode(
|
||||
example.text_a,
|
||||
add_special_tokens=True,
|
||||
max_length=min(max_length, tokenizer.max_len),
|
||||
example.text_a, add_special_tokens=True, max_length=min(max_length, tokenizer.max_len),
|
||||
)
|
||||
all_input_ids.append(input_ids)
|
||||
|
||||
@@ -256,8 +273,12 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
input_ids = input_ids + ([pad_token] * padding_length)
|
||||
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
|
||||
|
||||
assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(len(input_ids), batch_length)
|
||||
assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(len(attention_mask), batch_length)
|
||||
assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(
|
||||
len(input_ids), batch_length
|
||||
)
|
||||
assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(
|
||||
len(attention_mask), batch_length
|
||||
)
|
||||
|
||||
if self.mode == "classification":
|
||||
label = label_map[example.label]
|
||||
@@ -273,36 +294,31 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
|
||||
logger.info("label: %s (id = %d)" % (example.label, label))
|
||||
|
||||
features.append(
|
||||
InputFeatures(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
label=label))
|
||||
features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
|
||||
|
||||
if return_tensors is None:
|
||||
return features
|
||||
elif return_tensors == 'tf':
|
||||
elif return_tensors == "tf":
|
||||
if not is_tf_available():
|
||||
raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
|
||||
import tensorflow as tf
|
||||
|
||||
def gen():
|
||||
for ex in features:
|
||||
yield ({'input_ids': ex.input_ids,
|
||||
'attention_mask': ex.attention_mask},
|
||||
ex.label)
|
||||
yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)
|
||||
|
||||
dataset = tf.data.Dataset.from_generator(gen,
|
||||
({'input_ids': tf.int32,
|
||||
'attention_mask': tf.int32},
|
||||
tf.int64),
|
||||
({'input_ids': tf.TensorShape([None]),
|
||||
'attention_mask': tf.TensorShape([None])},
|
||||
tf.TensorShape([])))
|
||||
dataset = tf.data.Dataset.from_generator(
|
||||
gen,
|
||||
({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
|
||||
({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
|
||||
)
|
||||
return dataset
|
||||
elif return_tensors == 'pt':
|
||||
elif return_tensors == "pt":
|
||||
if not is_torch_available():
|
||||
raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset
|
||||
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
||||
if self.mode == "classification":
|
||||
|
||||
@@ -24,11 +24,12 @@ from .utils import DataProcessor, InputExample
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class XnliProcessor(DataProcessor):
|
||||
"""Processor for the XNLI dataset.
|
||||
Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""
|
||||
|
||||
def __init__(self, language, train_language = None):
|
||||
def __init__(self, language, train_language=None):
|
||||
self.language = language
|
||||
self.train_language = train_language
|
||||
|
||||
@@ -40,13 +41,12 @@ class XnliProcessor(DataProcessor):
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % ('train', i)
|
||||
guid = "%s-%s" % ("train", i)
|
||||
text_a = line[0]
|
||||
text_b = line[1]
|
||||
label = "contradiction" if line[2] == "contradictory" else line[2]
|
||||
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
@@ -59,19 +59,19 @@ class XnliProcessor(DataProcessor):
|
||||
language = line[0]
|
||||
if language != self.language:
|
||||
continue
|
||||
guid = "%s-%s" % ('test', i)
|
||||
guid = "%s-%s" % ("test", i)
|
||||
text_a = line[6]
|
||||
text_b = line[7]
|
||||
label = line[1]
|
||||
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["contradiction", "entailment", "neutral"]
|
||||
|
||||
|
||||
xnli_processors = {
|
||||
"xnli": XnliProcessor,
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ Utilities for working with the local dataset cache.
|
||||
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
||||
Copyright by the AllenNLP authors.
|
||||
"""
|
||||
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import sys
|
||||
import json
|
||||
@@ -29,9 +29,10 @@ from filelock import FileLock
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
try:
|
||||
os.environ.setdefault('USE_TORCH', 'YES')
|
||||
if os.environ['USE_TORCH'].upper() in ('1', 'ON', 'YES'):
|
||||
os.environ.setdefault("USE_TORCH", "YES")
|
||||
if os.environ["USE_TORCH"].upper() in ("1", "ON", "YES"):
|
||||
import torch
|
||||
|
||||
_torch_available = True # pylint: disable=invalid-name
|
||||
logger.info("PyTorch version {} available.".format(torch.__version__))
|
||||
else:
|
||||
@@ -41,10 +42,11 @@ except ImportError:
|
||||
_torch_available = False # pylint: disable=invalid-name
|
||||
|
||||
try:
|
||||
os.environ.setdefault('USE_TF', 'YES')
|
||||
if os.environ['USE_TF'].upper() in ('1', 'ON', 'YES'):
|
||||
os.environ.setdefault("USE_TF", "YES")
|
||||
if os.environ["USE_TF"].upper() in ("1", "ON", "YES"):
|
||||
import tensorflow as tf
|
||||
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
|
||||
|
||||
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
|
||||
_tf_available = True # pylint: disable=invalid-name
|
||||
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
||||
else:
|
||||
@@ -55,12 +57,13 @@ except (ImportError, AssertionError):
|
||||
|
||||
try:
|
||||
from torch.hub import _get_torch_home
|
||||
|
||||
torch_cache_home = _get_torch_home()
|
||||
except ImportError:
|
||||
torch_cache_home = os.path.expanduser(
|
||||
os.getenv('TORCH_HOME', os.path.join(
|
||||
os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
|
||||
default_cache_path = os.path.join(torch_cache_home, 'transformers')
|
||||
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
|
||||
)
|
||||
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
@@ -69,19 +72,21 @@ except ImportError:
|
||||
|
||||
try:
|
||||
from pathlib import Path
|
||||
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = Path(
|
||||
os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
|
||||
os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
|
||||
)
|
||||
except (AttributeError, ImportError):
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
|
||||
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||
default_cache_path))
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
|
||||
"PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
||||
)
|
||||
|
||||
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
||||
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
||||
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
TF2_WEIGHTS_NAME = 'tf_model.h5'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
TF2_WEIGHTS_NAME = "tf_model.h5"
|
||||
TF_WEIGHTS_NAME = "model.ckpt"
|
||||
CONFIG_NAME = "config.json"
|
||||
MODEL_CARD_NAME = "modelcard.json"
|
||||
|
||||
@@ -95,38 +100,48 @@ CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
|
||||
def is_tf_available():
|
||||
|
||||
return _tf_available
|
||||
|
||||
|
||||
if not six.PY2:
|
||||
|
||||
def add_start_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
fn.__doc__ = ''.join(docstr) + fn.__doc__
|
||||
fn.__doc__ = "".join(docstr) + fn.__doc__
|
||||
return fn
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
def add_end_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
fn.__doc__ = fn.__doc__ + ''.join(docstr)
|
||||
fn.__doc__ = fn.__doc__ + "".join(docstr)
|
||||
return fn
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
else:
|
||||
# Not possible to update class docstrings on python2
|
||||
def add_start_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
return fn
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
def add_end_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
return fn
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
def is_remote_url(url_or_filename):
|
||||
parsed = urlparse(url_or_filename)
|
||||
return parsed.scheme in ('http', 'https', 's3')
|
||||
return parsed.scheme in ("http", "https", "s3")
|
||||
|
||||
|
||||
def hf_bucket_url(identifier, postfix=None, cdn=False):
|
||||
endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
|
||||
@@ -145,17 +160,17 @@ def url_to_filename(url, etag=None):
|
||||
so that TF 2.0 can identify it as a HDF5 file
|
||||
(see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
|
||||
"""
|
||||
url_bytes = url.encode('utf-8')
|
||||
url_bytes = url.encode("utf-8")
|
||||
url_hash = sha256(url_bytes)
|
||||
filename = url_hash.hexdigest()
|
||||
|
||||
if etag:
|
||||
etag_bytes = etag.encode('utf-8')
|
||||
etag_bytes = etag.encode("utf-8")
|
||||
etag_hash = sha256(etag_bytes)
|
||||
filename += '.' + etag_hash.hexdigest()
|
||||
filename += "." + etag_hash.hexdigest()
|
||||
|
||||
if url.endswith('.h5'):
|
||||
filename += '.h5'
|
||||
if url.endswith(".h5"):
|
||||
filename += ".h5"
|
||||
|
||||
return filename
|
||||
|
||||
@@ -174,19 +189,21 @@ def filename_to_url(filename, cache_dir=None):
|
||||
if not os.path.exists(cache_path):
|
||||
raise EnvironmentError("file {} not found".format(cache_path))
|
||||
|
||||
meta_path = cache_path + '.json'
|
||||
meta_path = cache_path + ".json"
|
||||
if not os.path.exists(meta_path):
|
||||
raise EnvironmentError("file {} not found".format(meta_path))
|
||||
|
||||
with open(meta_path, encoding="utf-8") as meta_file:
|
||||
metadata = json.load(meta_file)
|
||||
url = metadata['url']
|
||||
etag = metadata['etag']
|
||||
url = metadata["url"]
|
||||
etag = metadata["etag"]
|
||||
|
||||
return url, etag
|
||||
|
||||
|
||||
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None):
|
||||
def cached_path(
|
||||
url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
|
||||
):
|
||||
"""
|
||||
Given something that might be a URL (or might be a local path),
|
||||
determine which. If it's a URL, download the file and cache it, and
|
||||
@@ -207,13 +224,18 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
|
||||
|
||||
if is_remote_url(url_or_filename):
|
||||
# URL, so get it from the cache (downloading if necessary)
|
||||
return get_from_cache(url_or_filename, cache_dir=cache_dir,
|
||||
force_download=force_download, proxies=proxies,
|
||||
resume_download=resume_download, user_agent=user_agent)
|
||||
return get_from_cache(
|
||||
url_or_filename,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
elif os.path.exists(url_or_filename):
|
||||
# File, and it exists.
|
||||
return url_or_filename
|
||||
elif urlparse(url_or_filename).scheme == '':
|
||||
elif urlparse(url_or_filename).scheme == "":
|
||||
# File, but it doesn't exist.
|
||||
raise EnvironmentError("file {} not found".format(url_or_filename))
|
||||
else:
|
||||
@@ -273,31 +295,35 @@ def s3_get(url, temp_file, proxies=None):
|
||||
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
||||
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
|
||||
if isinstance(user_agent, dict):
|
||||
ua += "; " + "; ".join(
|
||||
"{}/{}".format(k, v) for k, v in user_agent.items()
|
||||
)
|
||||
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
|
||||
elif isinstance(user_agent, six.string_types):
|
||||
ua += "; "+ user_agent
|
||||
headers = {
|
||||
"user-agent": ua
|
||||
}
|
||||
ua += "; " + user_agent
|
||||
headers = {"user-agent": ua}
|
||||
if resume_size > 0:
|
||||
headers['Range'] = 'bytes=%d-' % (resume_size,)
|
||||
headers["Range"] = "bytes=%d-" % (resume_size,)
|
||||
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||
if response.status_code == 416: # Range not satisfiable
|
||||
return
|
||||
content_length = response.headers.get('Content-Length')
|
||||
content_length = response.headers.get("Content-Length")
|
||||
total = resume_size + int(content_length) if content_length is not None else None
|
||||
progress = tqdm(unit="B", unit_scale=True, total=total, initial=resume_size,
|
||||
desc="Downloading", disable=bool(logger.level<=logging.INFO))
|
||||
progress = tqdm(
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
total=total,
|
||||
initial=resume_size,
|
||||
desc="Downloading",
|
||||
disable=bool(logger.level <= logging.INFO),
|
||||
)
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
progress.update(len(chunk))
|
||||
temp_file.write(chunk)
|
||||
progress.close()
|
||||
|
||||
|
||||
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None):
|
||||
def get_from_cache(
|
||||
url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
|
||||
):
|
||||
"""
|
||||
Given a URL, look for the corresponding dataset in the local cache.
|
||||
If it's not there, download it. Then return the path to the cached file.
|
||||
@@ -326,7 +352,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
|
||||
etag = None
|
||||
|
||||
if sys.version_info[0] == 2 and etag is not None:
|
||||
etag = etag.decode('utf-8')
|
||||
etag = etag.decode("utf-8")
|
||||
filename = url_to_filename(url, etag)
|
||||
|
||||
# get cache path to put the file
|
||||
@@ -337,22 +363,24 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
|
||||
if not os.path.exists(cache_path) and etag is None:
|
||||
matching_files = [
|
||||
file
|
||||
for file in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
|
||||
if not file.endswith('.json') and not file.endswith('.lock')
|
||||
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
|
||||
if not file.endswith(".json") and not file.endswith(".lock")
|
||||
]
|
||||
if matching_files:
|
||||
cache_path = os.path.join(cache_dir, matching_files[-1])
|
||||
|
||||
# Prevent parallel downloads of the same file with a lock.
|
||||
lock_path = cache_path + '.lock'
|
||||
lock_path = cache_path + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
if resume_download:
|
||||
incomplete_path = cache_path + '.incomplete'
|
||||
incomplete_path = cache_path + ".incomplete"
|
||||
|
||||
@contextmanager
|
||||
def _resumable_file_manager():
|
||||
with open(incomplete_path,'a+b') as f:
|
||||
with open(incomplete_path, "a+b") as f:
|
||||
yield f
|
||||
|
||||
temp_file_manager = _resumable_file_manager
|
||||
if os.path.exists(incomplete_path):
|
||||
resume_size = os.stat(incomplete_path).st_size
|
||||
@@ -366,7 +394,9 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||
logger.info(
|
||||
"%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
|
||||
)
|
||||
|
||||
# GET file object
|
||||
if url.startswith("s3://"):
|
||||
@@ -383,12 +413,12 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
|
||||
os.rename(temp_file.name, cache_path)
|
||||
|
||||
logger.info("creating metadata file for %s", cache_path)
|
||||
meta = {'url': url, 'etag': etag}
|
||||
meta_path = cache_path + '.json'
|
||||
with open(meta_path, 'w') as meta_file:
|
||||
meta = {"url": url, "etag": etag}
|
||||
meta_path = cache_path + ".json"
|
||||
with open(meta_path, "w") as meta_file:
|
||||
output_string = json.dumps(meta)
|
||||
if sys.version_info[0] == 2 and isinstance(output_string, str):
|
||||
output_string = unicode(output_string, 'utf-8') # The beauty of python 2
|
||||
output_string = unicode(output_string, "utf-8") # The beauty of python 2
|
||||
meta_file.write(output_string)
|
||||
|
||||
return cache_path
|
||||
|
||||
@@ -24,13 +24,14 @@ from tqdm import tqdm
|
||||
|
||||
ENDPOINT = "https://huggingface.co"
|
||||
|
||||
|
||||
class S3Obj:
|
||||
def __init__(
|
||||
self,
|
||||
filename, # type: str
|
||||
LastModified, # type: str
|
||||
ETag, # type: str
|
||||
Size, # type: int
|
||||
filename, # type: str
|
||||
LastModified, # type: str
|
||||
ETag, # type: str
|
||||
Size, # type: int
|
||||
**kwargs
|
||||
):
|
||||
self.filename = filename
|
||||
@@ -43,13 +44,13 @@ class PresignedUrl:
|
||||
def __init__(
|
||||
self,
|
||||
write, # type: str
|
||||
access, # type: str
|
||||
type, # type: str
|
||||
access, # type: str
|
||||
type, # type: str
|
||||
**kwargs
|
||||
):
|
||||
self.write = write
|
||||
self.access = access
|
||||
self.type = type # mime-type to send to S3.
|
||||
self.type = type # mime-type to send to S3.
|
||||
|
||||
|
||||
class HfApi:
|
||||
@@ -58,8 +59,8 @@ class HfApi:
|
||||
|
||||
def login(
|
||||
self,
|
||||
username, # type: str
|
||||
password, # type: str
|
||||
username, # type: str
|
||||
password, # type: str
|
||||
):
|
||||
# type: (...) -> str
|
||||
"""
|
||||
@@ -78,8 +79,7 @@ class HfApi:
|
||||
return d["token"]
|
||||
|
||||
def whoami(
|
||||
self,
|
||||
token, # type: str
|
||||
self, token, # type: str
|
||||
):
|
||||
# type: (...) -> str
|
||||
"""
|
||||
@@ -106,11 +106,7 @@ class HfApi:
|
||||
Call HF API to get a presigned url to upload `filename` to S3.
|
||||
"""
|
||||
path = "{}/api/presign".format(self.endpoint)
|
||||
r = requests.post(
|
||||
path,
|
||||
headers={"authorization": "Bearer {}".format(token)},
|
||||
json={"filename": filename},
|
||||
)
|
||||
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename},)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
return PresignedUrl(**d)
|
||||
@@ -133,9 +129,7 @@ class HfApi:
|
||||
pf = TqdmProgressFileReader(f)
|
||||
data = f if pf.total_size > 0 else ""
|
||||
|
||||
r = requests.put(urls.write, data=data, headers={
|
||||
"content-type": urls.type,
|
||||
})
|
||||
r = requests.put(urls.write, data=data, headers={"content-type": urls.type,})
|
||||
r.raise_for_status()
|
||||
pf.close()
|
||||
return urls.access
|
||||
@@ -152,7 +146,6 @@ class HfApi:
|
||||
return [S3Obj(**x) for x in d]
|
||||
|
||||
|
||||
|
||||
class TqdmProgressFileReader:
|
||||
"""
|
||||
Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`)
|
||||
@@ -161,12 +154,12 @@ class TqdmProgressFileReader:
|
||||
see github.com/huggingface/transformers/pull/2078#discussion_r354739608
|
||||
for implementation details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
f # type: io.BufferedReader
|
||||
self, f # type: io.BufferedReader
|
||||
):
|
||||
self.f = f
|
||||
self.total_size = os.fstat(f.fileno()).st_size # type: int
|
||||
self.total_size = os.fstat(f.fileno()).st_size # type: int
|
||||
self.pbar = tqdm(total=self.total_size, leave=False)
|
||||
if six.PY3:
|
||||
# does not work unless PY3
|
||||
@@ -182,7 +175,6 @@ class TqdmProgressFileReader:
|
||||
self.pbar.close()
|
||||
|
||||
|
||||
|
||||
class HfFolder:
|
||||
path_token = expanduser("~/.huggingface/token")
|
||||
|
||||
@@ -201,7 +193,7 @@ class HfFolder:
|
||||
if e.errno != os.errno.EEXIST:
|
||||
raise e
|
||||
pass
|
||||
with open(cls.path_token, 'w+') as f:
|
||||
with open(cls.path_token, "w+") as f:
|
||||
f.write(token)
|
||||
|
||||
@classmethod
|
||||
@@ -210,7 +202,7 @@ class HfFolder:
|
||||
Get token or None if not existent.
|
||||
"""
|
||||
try:
|
||||
with open(cls.path_token, 'r') as f:
|
||||
with open(cls.path_token, "r") as f:
|
||||
return f.read()
|
||||
except:
|
||||
# this is too wide. When Py2 is dead use:
|
||||
|
||||
@@ -14,8 +14,7 @@
|
||||
# limitations under the License.
|
||||
""" Configuration base class and utilities."""
|
||||
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import copy
|
||||
import json
|
||||
@@ -25,8 +24,15 @@ from io import open
|
||||
|
||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
from .file_utils import CONFIG_NAME, MODEL_CARD_NAME, WEIGHTS_NAME, TF2_WEIGHTS_NAME, \
|
||||
cached_path, is_remote_url, hf_bucket_url
|
||||
from .file_utils import (
|
||||
CONFIG_NAME,
|
||||
MODEL_CARD_NAME,
|
||||
WEIGHTS_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
cached_path,
|
||||
is_remote_url,
|
||||
hf_bucket_url,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -48,17 +54,18 @@ class ModelCard(object):
|
||||
|
||||
Parameters:
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Recomended attributes from https://arxiv.org/abs/1810.03993 (see papers)
|
||||
self.model_details = kwargs.pop('model_details', {})
|
||||
self.intended_use = kwargs.pop('intended_use', {})
|
||||
self.factors = kwargs.pop('factors', {})
|
||||
self.metrics = kwargs.pop('metrics', {})
|
||||
self.evaluation_data = kwargs.pop('evaluation_data', {})
|
||||
self.training_data = kwargs.pop('training_data', {})
|
||||
self.quantitative_analyses = kwargs.pop('quantitative_analyses', {})
|
||||
self.ethical_considerations = kwargs.pop('ethical_considerations', {})
|
||||
self.caveats_and_recommendations = kwargs.pop('caveats_and_recommendations', {})
|
||||
self.model_details = kwargs.pop("model_details", {})
|
||||
self.intended_use = kwargs.pop("intended_use", {})
|
||||
self.factors = kwargs.pop("factors", {})
|
||||
self.metrics = kwargs.pop("metrics", {})
|
||||
self.evaluation_data = kwargs.pop("evaluation_data", {})
|
||||
self.training_data = kwargs.pop("training_data", {})
|
||||
self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
|
||||
self.ethical_considerations = kwargs.pop("ethical_considerations", {})
|
||||
self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})
|
||||
|
||||
# Open additional attributes
|
||||
for key, value in kwargs.items():
|
||||
@@ -122,10 +129,10 @@ class ModelCard(object):
|
||||
modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
|
||||
|
||||
"""
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
proxies = kwargs.pop('proxies', None)
|
||||
find_from_standard_name = kwargs.pop('find_from_standard_name', True)
|
||||
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
find_from_standard_name = kwargs.pop("find_from_standard_name", True)
|
||||
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||
|
||||
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
||||
# For simplicity we use the same pretrained url than the configuration files
|
||||
@@ -145,36 +152,43 @@ class ModelCard(object):
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, force_download=True,
|
||||
proxies=proxies, resume_download=False)
|
||||
resolved_model_card_file = cached_path(
|
||||
model_card_file, cache_dir=cache_dir, force_download=True, proxies=proxies, resume_download=False
|
||||
)
|
||||
if resolved_model_card_file == model_card_file:
|
||||
logger.info("loading model card file {}".format(model_card_file))
|
||||
else:
|
||||
logger.info("loading model card file {} from cache at {}".format(
|
||||
model_card_file, resolved_model_card_file))
|
||||
logger.info(
|
||||
"loading model card file {} from cache at {}".format(model_card_file, resolved_model_card_file)
|
||||
)
|
||||
# Load model card
|
||||
modelcard = cls.from_json_file(resolved_model_card_file)
|
||||
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
||||
logger.warning("Couldn't reach server at '{}' to download model card file.".format(
|
||||
model_card_file))
|
||||
logger.warning("Couldn't reach server at '{}' to download model card file.".format(model_card_file))
|
||||
else:
|
||||
logger.warning("Model name '{}' was not found in model name list ({}). " \
|
||||
"We assumed '{}' was a path or url to a model card file named {} or " \
|
||||
"a directory containing such a file but couldn't find any such file at this path or url.".format(
|
||||
logger.warning(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url to a model card file named {} or "
|
||||
"a directory containing such a file but couldn't find any such file at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
', '.join(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
|
||||
model_card_file, MODEL_CARD_NAME))
|
||||
", ".join(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
|
||||
model_card_file,
|
||||
MODEL_CARD_NAME,
|
||||
)
|
||||
)
|
||||
logger.warning("Creating an empty model card.")
|
||||
|
||||
# We fall back on creating an empty model card
|
||||
modelcard = cls()
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Couldn't reach server at '{}' to download model card file or "
|
||||
"model card file is not a valid JSON file. "
|
||||
"Please check network or file content here: {}.".format(model_card_file, resolved_model_card_file))
|
||||
logger.warning(
|
||||
"Couldn't reach server at '{}' to download model card file or "
|
||||
"model card file is not a valid JSON file. "
|
||||
"Please check network or file content here: {}.".format(model_card_file, resolved_model_card_file)
|
||||
)
|
||||
logger.warning("Creating an empty model card.")
|
||||
|
||||
# We fall back on creating an empty model card
|
||||
@@ -203,7 +217,7 @@ class ModelCard(object):
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `ModelCard` from a json file of parameters."""
|
||||
with open(json_file, "r", encoding='utf-8') as reader:
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
dict_obj = json.loads(text)
|
||||
return cls(**dict_obj)
|
||||
@@ -225,5 +239,5 @@ class ModelCard(object):
|
||||
|
||||
def to_json_file(self, json_file_path):
|
||||
""" Save this instance to a json file."""
|
||||
with open(json_file_path, "w", encoding='utf-8') as writer:
|
||||
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
|
||||
#
|
||||
@@ -30,14 +29,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
|
||||
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
|
||||
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
|
||||
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
|
||||
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin",
|
||||
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-pytorch_model.bin",
|
||||
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-pytorch_model.bin",
|
||||
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-pytorch_model.bin",
|
||||
"albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
|
||||
"albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
|
||||
"albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
|
||||
"albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
|
||||
"albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin",
|
||||
"albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-pytorch_model.bin",
|
||||
"albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-pytorch_model.bin",
|
||||
"albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-pytorch_model.bin",
|
||||
}
|
||||
|
||||
|
||||
@@ -48,8 +47,10 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
logger.error(
|
||||
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
@@ -109,7 +110,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
||||
if "seq_relationship" in name:
|
||||
continue
|
||||
|
||||
name = name.split('/')
|
||||
name = name.split("/")
|
||||
|
||||
# Ignore the gradients applied by the LAMB/ADAM optimizers.
|
||||
if "adam_m" in name or "adam_v" in name or "global_step" in name:
|
||||
@@ -118,19 +119,19 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
||||
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
l = re.split(r'_(\d+)', m_name)
|
||||
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
||||
l = re.split(r"_(\d+)", m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
|
||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'output_weights':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'squad':
|
||||
pointer = getattr(pointer, 'classifier')
|
||||
if l[0] == "kernel" or l[0] == "gamma":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif l[0] == "output_bias" or l[0] == "beta":
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif l[0] == "output_weights":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif l[0] == "squad":
|
||||
pointer = getattr(pointer, "classifier")
|
||||
else:
|
||||
try:
|
||||
pointer = getattr(pointer, l[0])
|
||||
@@ -141,9 +142,9 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
|
||||
if m_name[-11:] == '_embeddings':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif m_name == 'kernel':
|
||||
if m_name[-11:] == "_embeddings":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif m_name == "kernel":
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
@@ -160,6 +161,7 @@ class AlbertEmbeddings(BertEmbeddings):
|
||||
"""
|
||||
Construct the embeddings from word, position and token_type embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(AlbertEmbeddings, self).__init__(config)
|
||||
|
||||
@@ -238,9 +240,12 @@ class AlbertAttention(BertSelfAttention):
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
reshaped_context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
|
||||
# Should find a better way to do this
|
||||
w = self.dense.weight.t().view(self.num_attention_heads, self.attention_head_size, self.hidden_size).to(context_layer.dtype)
|
||||
w = (
|
||||
self.dense.weight.t()
|
||||
.view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
|
||||
.to(context_layer.dtype)
|
||||
)
|
||||
b = self.dense.bias.to(context_layer.dtype)
|
||||
|
||||
projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
|
||||
@@ -328,7 +333,11 @@ class AlbertTransformer(nn.Module):
|
||||
# Index of the layer inside the group
|
||||
layer_idx = int(i - group_idx * layers_per_group)
|
||||
|
||||
layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group])
|
||||
layer_group_output = self.albert_layer_groups[group_idx](
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
|
||||
)
|
||||
hidden_states = layer_group_output[0]
|
||||
|
||||
if self.output_attentions:
|
||||
@@ -337,7 +346,6 @@ class AlbertTransformer(nn.Module):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
@@ -346,11 +354,11 @@ class AlbertTransformer(nn.Module):
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
|
||||
class AlbertPreTrainedModel(PreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = AlbertConfig
|
||||
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "albert"
|
||||
@@ -431,8 +439,12 @@ ALBERT_INPUTS_DOCSTRING = r"""
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
ALBERT_START_DOCSTRING,
|
||||
ALBERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class AlbertModel(AlbertPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -500,8 +512,15 @@ class AlbertModel(AlbertPreTrainedModel):
|
||||
inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
|
||||
self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||
inputs_embeds=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
):
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
@@ -520,31 +539,37 @@ class AlbertModel(AlbertPreTrainedModel):
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
if head_mask is not None:
|
||||
if head_mask.dim() == 1:
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
||||
head_mask = (
|
||||
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
||||
) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
) # switch to fload if need + fp16 compatibility
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds)
|
||||
encoder_outputs = self.encoder(embedding_output,
|
||||
extended_attention_mask,
|
||||
head_mask=head_mask)
|
||||
embedding_output = self.embeddings(
|
||||
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
|
||||
pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
|
||||
|
||||
outputs = (sequence_output, pooled_output) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||
outputs = (sequence_output, pooled_output) + encoder_outputs[
|
||||
1:
|
||||
] # add hidden_states and attentions if they are here
|
||||
return outputs
|
||||
|
||||
|
||||
class AlbertMLMHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(AlbertMLMHead, self).__init__()
|
||||
@@ -566,7 +591,9 @@ class AlbertMLMHead(nn.Module):
|
||||
return prediction_scores
|
||||
|
||||
|
||||
@add_start_docstrings("Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings(
|
||||
"Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING
|
||||
)
|
||||
class AlbertForMaskedLM(AlbertPreTrainedModel):
|
||||
r"""
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -602,21 +629,28 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
self._tie_or_clone_weights(self.predictions.decoder,
|
||||
self.albert.embeddings.word_embeddings)
|
||||
self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.predictions.decoder
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
masked_lm_labels=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
masked_lm_labels=None,
|
||||
):
|
||||
outputs = self.albert(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
sequence_outputs = outputs[0]
|
||||
|
||||
@@ -631,9 +665,12 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
|
||||
return outputs
|
||||
|
||||
|
||||
@add_start_docstrings("""Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
|
||||
ALBERT_START_DOCSTRING,
|
||||
ALBERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -665,6 +702,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(AlbertForSequenceClassification, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@@ -675,8 +713,16 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
|
||||
outputs = self.albert(
|
||||
input_ids=input_ids,
|
||||
@@ -684,7 +730,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@@ -707,10 +753,12 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
|
||||
@add_start_docstrings("""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
@add_start_docstrings(
|
||||
"""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
|
||||
ALBERT_START_DOCSTRING,
|
||||
ALBERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -752,6 +800,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(AlbertForQuestionAnswering, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@@ -761,8 +810,17 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||
inputs_embeds=None, start_positions=None, end_positions=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
):
|
||||
|
||||
outputs = self.albert(
|
||||
input_ids=input_ids,
|
||||
@@ -770,7 +828,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
@@ -18,31 +18,87 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
|
||||
import logging
|
||||
|
||||
from .configuration_auto import (AlbertConfig, BertConfig, CamembertConfig, CTRLConfig,
|
||||
DistilBertConfig, GPT2Config, OpenAIGPTConfig, RobertaConfig,
|
||||
TransfoXLConfig, XLMConfig, XLNetConfig, XLMRobertaConfig)
|
||||
from .configuration_auto import (
|
||||
AlbertConfig,
|
||||
BertConfig,
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
DistilBertConfig,
|
||||
GPT2Config,
|
||||
OpenAIGPTConfig,
|
||||
RobertaConfig,
|
||||
TransfoXLConfig,
|
||||
XLMConfig,
|
||||
XLNetConfig,
|
||||
XLMRobertaConfig,
|
||||
)
|
||||
|
||||
from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering, \
|
||||
BertForTokenClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_bert import (
|
||||
BertModel,
|
||||
BertForMaskedLM,
|
||||
BertForSequenceClassification,
|
||||
BertForQuestionAnswering,
|
||||
BertForTokenClassification,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_ctrl import CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering, \
|
||||
XLNetForTokenClassification, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering, \
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, \
|
||||
RobertaForTokenClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, \
|
||||
DistilBertForSequenceClassification, DistilBertForTokenClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, \
|
||||
CamembertForMultipleChoice, CamembertForTokenClassification, CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, \
|
||||
AlbertForQuestionAnswering, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_xlnet import (
|
||||
XLNetModel,
|
||||
XLNetLMHeadModel,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetForTokenClassification,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_xlm import (
|
||||
XLMModel,
|
||||
XLMWithLMHeadModel,
|
||||
XLMForSequenceClassification,
|
||||
XLMForQuestionAnswering,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_roberta import (
|
||||
RobertaModel,
|
||||
RobertaForMaskedLM,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaForTokenClassification,
|
||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_distilbert import (
|
||||
DistilBertModel,
|
||||
DistilBertForQuestionAnswering,
|
||||
DistilBertForMaskedLM,
|
||||
DistilBertForSequenceClassification,
|
||||
DistilBertForTokenClassification,
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_camembert import (
|
||||
CamembertModel,
|
||||
CamembertForMaskedLM,
|
||||
CamembertForSequenceClassification,
|
||||
CamembertForMultipleChoice,
|
||||
CamembertForTokenClassification,
|
||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_albert import (
|
||||
AlbertModel,
|
||||
AlbertForMaskedLM,
|
||||
AlbertForSequenceClassification,
|
||||
AlbertForQuestionAnswering,
|
||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_t5 import T5Model, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification, \
|
||||
XLMRobertaForMultipleChoice, XLMRobertaForTokenClassification, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_xlm_roberta import (
|
||||
XLMRobertaModel,
|
||||
XLMRobertaForMaskedLM,
|
||||
XLMRobertaForSequenceClassification,
|
||||
XLMRobertaForMultipleChoice,
|
||||
XLMRobertaForTokenClassification,
|
||||
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||
|
||||
@@ -51,7 +107,8 @@ from .file_utils import add_start_docstrings
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict((key, value)
|
||||
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
|
||||
(key, value)
|
||||
for pretrained_map in [
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
@@ -66,8 +123,9 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict((key, value)
|
||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
]
|
||||
for key, value, in pretrained_map.items())
|
||||
]
|
||||
for key, value, in pretrained_map.items()
|
||||
)
|
||||
|
||||
|
||||
class AutoModel(object):
|
||||
@@ -98,10 +156,13 @@ class AutoModel(object):
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError("AutoModel is designed to be instantiated "
|
||||
raise EnvironmentError(
|
||||
"AutoModel is designed to be instantiated "
|
||||
"using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
"`AutoModel.from_config(config)` methods.")
|
||||
"`AutoModel.from_config(config)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
@@ -232,35 +293,39 @@ class AutoModel(object):
|
||||
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if 't5' in pretrained_model_name_or_path:
|
||||
if "t5" in pretrained_model_name_or_path:
|
||||
return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'distilbert' in pretrained_model_name_or_path:
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'albert' in pretrained_model_name_or_path:
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'camembert' in pretrained_model_name_or_path:
|
||||
elif "camembert" in pretrained_model_name_or_path:
|
||||
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm-roberta' in pretrained_model_name_or_path:
|
||||
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||
return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'bert' in pretrained_model_name_or_path:
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'openai-gpt' in pretrained_model_name_or_path:
|
||||
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||
return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'gpt2' in pretrained_model_name_or_path:
|
||||
elif "gpt2" in pretrained_model_name_or_path:
|
||||
return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'transfo-xl' in pretrained_model_name_or_path:
|
||||
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||
return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlnet' in pretrained_model_name_or_path:
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm' in pretrained_model_name_or_path:
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'ctrl' in pretrained_model_name_or_path:
|
||||
elif "ctrl" in pretrained_model_name_or_path:
|
||||
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm-roberta', 'xlm', 'roberta, 'ctrl', 'distilbert', 'camembert', 'albert'".format(pretrained_model_name_or_path))
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm-roberta', 'xlm', 'roberta, 'ctrl', 'distilbert', 'camembert', 'albert'".format(
|
||||
pretrained_model_name_or_path
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AutoModelWithLMHead(object):
|
||||
@@ -291,10 +356,13 @@ class AutoModelWithLMHead(object):
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError("AutoModelWithLMHead is designed to be instantiated "
|
||||
raise EnvironmentError(
|
||||
"AutoModelWithLMHead is designed to be instantiated "
|
||||
"using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
"`AutoModelWithLMHead.from_config(config)` methods.")
|
||||
"`AutoModelWithLMHead.from_config(config)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
@@ -423,35 +491,39 @@ class AutoModelWithLMHead(object):
|
||||
model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if 't5' in pretrained_model_name_or_path:
|
||||
if "t5" in pretrained_model_name_or_path:
|
||||
return T5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'distilbert' in pretrained_model_name_or_path:
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'albert' in pretrained_model_name_or_path:
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return AlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'camembert' in pretrained_model_name_or_path:
|
||||
elif "camembert" in pretrained_model_name_or_path:
|
||||
return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm-roberta' in pretrained_model_name_or_path:
|
||||
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||
return XLMRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'bert' in pretrained_model_name_or_path:
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'openai-gpt' in pretrained_model_name_or_path:
|
||||
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||
return OpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'gpt2' in pretrained_model_name_or_path:
|
||||
elif "gpt2" in pretrained_model_name_or_path:
|
||||
return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'transfo-xl' in pretrained_model_name_or_path:
|
||||
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||
return TransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlnet' in pretrained_model_name_or_path:
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return XLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm' in pretrained_model_name_or_path:
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return XLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'ctrl' in pretrained_model_name_or_path:
|
||||
elif "ctrl" in pretrained_model_name_or_path:
|
||||
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm-roberta', 'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'".format(pretrained_model_name_or_path))
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm-roberta', 'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'".format(
|
||||
pretrained_model_name_or_path
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForSequenceClassification(object):
|
||||
@@ -477,10 +549,13 @@ class AutoModelForSequenceClassification(object):
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError("AutoModelForSequenceClassification is designed to be instantiated "
|
||||
raise EnvironmentError(
|
||||
"AutoModelForSequenceClassification is designed to be instantiated "
|
||||
"using the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
"`AutoModelForSequenceClassification.from_config(config)` methods.")
|
||||
"`AutoModelForSequenceClassification.from_config(config)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
@@ -597,25 +672,39 @@ class AutoModelForSequenceClassification(object):
|
||||
model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if 'distilbert' in pretrained_model_name_or_path:
|
||||
return DistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'albert' in pretrained_model_name_or_path:
|
||||
return AlbertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'camembert' in pretrained_model_name_or_path:
|
||||
return CamembertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm-roberta' in pretrained_model_name_or_path:
|
||||
return XLMRobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
return RobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'bert' in pretrained_model_name_or_path:
|
||||
if "distilbert" in pretrained_model_name_or_path:
|
||||
return DistilBertForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
)
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return AlbertForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
)
|
||||
elif "camembert" in pretrained_model_name_or_path:
|
||||
return CamembertForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
)
|
||||
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||
return XLMRobertaForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return RobertaForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlnet' in pretrained_model_name_or_path:
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm' in pretrained_model_name_or_path:
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'".format(pretrained_model_name_or_path))
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'".format(
|
||||
pretrained_model_name_or_path
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForQuestionAnswering(object):
|
||||
@@ -638,10 +727,13 @@ class AutoModelForQuestionAnswering(object):
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError("AutoModelForQuestionAnswering is designed to be instantiated "
|
||||
raise EnvironmentError(
|
||||
"AutoModelForQuestionAnswering is designed to be instantiated "
|
||||
"using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
"`AutoModelForQuestionAnswering.from_config(config)` methods.")
|
||||
"`AutoModelForQuestionAnswering.from_config(config)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
@@ -745,26 +837,30 @@ class AutoModelForQuestionAnswering(object):
|
||||
model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if 'distilbert' in pretrained_model_name_or_path:
|
||||
if "distilbert" in pretrained_model_name_or_path:
|
||||
return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'albert' in pretrained_model_name_or_path:
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return AlbertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'bert' in pretrained_model_name_or_path:
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlnet' in pretrained_model_name_or_path:
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return XLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm' in pretrained_model_name_or_path:
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'xlnet', 'xlm', 'distilbert', 'albert'".format(pretrained_model_name_or_path))
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'xlnet', 'xlm', 'distilbert', 'albert'".format(pretrained_model_name_or_path)
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForTokenClassification:
|
||||
def __init__(self):
|
||||
raise EnvironmentError("AutoModelForTokenClassification is designed to be instantiated "
|
||||
"using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
"`AutoModelForTokenClassification.from_config(config)` methods.")
|
||||
raise EnvironmentError(
|
||||
"AutoModelForTokenClassification is designed to be instantiated "
|
||||
"using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
"`AutoModelForTokenClassification.from_config(config)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
@@ -870,18 +966,28 @@ class AutoModelForTokenClassification:
|
||||
model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if 'camembert' in pretrained_model_name_or_path:
|
||||
return CamembertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'distilbert' in pretrained_model_name_or_path:
|
||||
return DistilBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm-roberta' in pretrained_model_name_or_path:
|
||||
return XLMRobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
if "camembert" in pretrained_model_name_or_path:
|
||||
return CamembertForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
)
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
return DistilBertForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
)
|
||||
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||
return XLMRobertaForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'bert' in pretrained_model_name_or_path:
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlnet' in pretrained_model_name_or_path:
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(pretrained_model_name_or_path))
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(
|
||||
pretrained_model_name_or_path
|
||||
)
|
||||
)
|
||||
|
||||
@@ -33,27 +33,27 @@ from .file_utils import add_start_docstrings
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
|
||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
|
||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
|
||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
|
||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
|
||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
|
||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
|
||||
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
|
||||
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
|
||||
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
|
||||
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
||||
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
||||
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
|
||||
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
|
||||
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
|
||||
'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
|
||||
'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
|
||||
'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
|
||||
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
|
||||
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
|
||||
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
|
||||
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
|
||||
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
|
||||
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
|
||||
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
|
||||
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
|
||||
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
|
||||
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
|
||||
"bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
|
||||
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
|
||||
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
|
||||
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
||||
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
||||
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
|
||||
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
|
||||
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
|
||||
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
|
||||
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
|
||||
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
|
||||
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
|
||||
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
|
||||
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
|
||||
}
|
||||
|
||||
|
||||
@@ -65,8 +65,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
logger.error(
|
||||
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
@@ -81,7 +83,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
arrays.append(array)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name.split('/')
|
||||
name = name.split("/")
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
||||
@@ -89,18 +91,18 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
l = re.split(r'_(\d+)', m_name)
|
||||
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
||||
l = re.split(r"_(\d+)", m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'output_weights':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'squad':
|
||||
pointer = getattr(pointer, 'classifier')
|
||||
if l[0] == "kernel" or l[0] == "gamma":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif l[0] == "output_bias" or l[0] == "beta":
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif l[0] == "output_weights":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif l[0] == "squad":
|
||||
pointer = getattr(pointer, "classifier")
|
||||
else:
|
||||
try:
|
||||
pointer = getattr(pointer, l[0])
|
||||
@@ -110,9 +112,9 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
if m_name[-11:] == '_embeddings':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif m_name == 'kernel':
|
||||
if m_name[-11:] == "_embeddings":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif m_name == "kernel":
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
@@ -157,6 +159,7 @@ BertLayerNorm = torch.nn.LayerNorm
|
||||
class BertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertEmbeddings, self).__init__()
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
||||
@@ -199,7 +202,8 @@ class BertSelfAttention(nn.Module):
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"The hidden size (%d) is not a multiple of the number of attention "
|
||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
||||
)
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
@@ -217,7 +221,14 @@ class BertSelfAttention(nn.Module):
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
@@ -307,8 +318,17 @@ class BertAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
self_outputs = self.self(
|
||||
hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
@@ -353,13 +373,22 @@ class BertLayer(nn.Module):
|
||||
self.intermediate = BertIntermediate(config)
|
||||
self.output = BertOutput(config)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
||||
|
||||
@@ -376,14 +405,23 @@ class BertEncoder(nn.Module):
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if self.output_attentions:
|
||||
@@ -440,9 +478,7 @@ class BertLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
@@ -488,6 +524,7 @@ class BertPreTrainedModel(PreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = BertConfig
|
||||
pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_bert
|
||||
@@ -581,8 +618,12 @@ BERT_INPUTS_DOCSTRING = r"""
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class BertModel(BertPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -612,6 +653,7 @@ class BertModel(BertPreTrainedModel):
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertModel, self).__init__(config)
|
||||
self.config = config
|
||||
@@ -636,8 +678,17 @@ class BertModel(BertPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
|
||||
head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
""" Forward pass on the Model.
|
||||
|
||||
The model can behave as an encoder (with only self-attention) as well
|
||||
@@ -681,12 +732,18 @@ class BertModel(BertPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
seq_ids = torch.arange(seq_length, device=device)
|
||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||
causal_mask = causal_mask.to(torch.long) # not converting to long will cause errors with pytorch version < 1.3
|
||||
causal_mask = causal_mask.to(
|
||||
torch.long
|
||||
) # not converting to long will cause errors with pytorch version < 1.3
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
else:
|
||||
extended_attention_mask = attention_mask[:, None, None, :]
|
||||
else:
|
||||
raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))
|
||||
raise ValueError(
|
||||
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
||||
input_shape, attention_mask.shape
|
||||
)
|
||||
)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -709,10 +766,15 @@ class BertModel(BertPreTrainedModel):
|
||||
elif encoder_attention_mask.dim() == 2:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
else:
|
||||
raise ValueError("Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(encoder_hidden_shape,
|
||||
encoder_attention_mask.shape))
|
||||
raise ValueError(
|
||||
"Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(
|
||||
encoder_hidden_shape, encoder_attention_mask.shape
|
||||
)
|
||||
)
|
||||
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
) # fp16 compatibility
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
@@ -727,28 +789,40 @@ class BertModel(BertPreTrainedModel):
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
||||
head_mask = (
|
||||
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
||||
) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
) # switch to fload if need + fp16 compatibility
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
|
||||
encoder_outputs = self.encoder(embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask)
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
||||
1:
|
||||
] # add hidden_states and attentions if they are here
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
|
||||
@add_start_docstrings(
|
||||
"""Bert Model with two heads on top as done during the pre-training:
|
||||
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class BertForPreTraining(BertPreTrainedModel):
|
||||
r"""
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -786,6 +860,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
prediction_scores, seq_relationship_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertForPreTraining, self).__init__(config)
|
||||
|
||||
@@ -797,20 +872,33 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
masked_lm_labels=None, next_sentence_label=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
masked_lm_labels=None,
|
||||
next_sentence_label=None,
|
||||
):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||
|
||||
outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
|
||||
outputs = (prediction_scores, seq_relationship_score,) + outputs[
|
||||
2:
|
||||
] # add hidden states and attention if they are here
|
||||
|
||||
if masked_lm_labels is not None and next_sentence_label is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
@@ -822,9 +910,9 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings(
|
||||
"""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING
|
||||
)
|
||||
class BertForMaskedLM(BertPreTrainedModel):
|
||||
r"""
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -862,6 +950,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
loss, prediction_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertForMaskedLM, self).__init__(config)
|
||||
|
||||
@@ -873,17 +962,30 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
masked_lm_labels=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
lm_labels=None,
|
||||
):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask)
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.cls(sequence_output)
|
||||
@@ -912,9 +1014,11 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
return outputs # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings(
|
||||
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
r"""
|
||||
**next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -945,6 +1049,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
seq_relationship_scores = outputs[0]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertForNextSentencePrediction, self).__init__(config)
|
||||
|
||||
@@ -953,15 +1058,25 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
next_sentence_label=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
next_sentence_label=None,
|
||||
):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
@@ -976,10 +1091,12 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class BertForSequenceClassification(BertPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -1011,6 +1128,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertForSequenceClassification, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@@ -1021,15 +1139,25 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
@@ -1051,10 +1179,12 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""Bert Model with a multiple choice classification head on top (a linear layer on top of
|
||||
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class BertForMultipleChoice(BertPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -1087,6 +1217,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
loss, classification_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertForMultipleChoice, self).__init__(config)
|
||||
|
||||
@@ -1096,8 +1227,16 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
num_choices = input_ids.shape[1]
|
||||
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
@@ -1105,12 +1244,14 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
@@ -1128,10 +1269,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""Bert Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class BertForTokenClassification(BertPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -1161,6 +1304,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertForTokenClassification, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@@ -1171,15 +1315,25 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
@@ -1202,10 +1356,12 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
return outputs # (loss), scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
@add_start_docstrings(
|
||||
"""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -1247,6 +1403,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertForQuestionAnswering, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@@ -1256,15 +1413,26 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
||||
start_positions=None, end_positions=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
|
||||
@@ -15,19 +15,24 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch CamemBERT model. """
|
||||
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import logging
|
||||
|
||||
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification
|
||||
from .modeling_roberta import (
|
||||
RobertaModel,
|
||||
RobertaForMaskedLM,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaForMultipleChoice,
|
||||
RobertaForTokenClassification,
|
||||
)
|
||||
from .configuration_camembert import CamembertConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-pytorch_model.bin",
|
||||
"camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-pytorch_model.bin",
|
||||
}
|
||||
|
||||
|
||||
@@ -100,8 +105,12 @@ CAMEMBERT_INPUTS_DOCSTRING = r"""
|
||||
than the model's internal embedding lookup matrix.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare CamemBERT Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare CamemBERT Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
CAMEMBERT_START_DOCSTRING,
|
||||
CAMEMBERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class CamembertModel(RobertaModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -149,8 +158,11 @@ class CamembertModel(RobertaModel):
|
||||
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
|
||||
@add_start_docstrings("""CamemBERT Model with a `language modeling` head on top. """,
|
||||
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings(
|
||||
"""CamemBERT Model with a `language modeling` head on top. """,
|
||||
CAMEMBERT_START_DOCSTRING,
|
||||
CAMEMBERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class CamembertForMaskedLM(RobertaForMaskedLM):
|
||||
r"""
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
@@ -185,9 +197,12 @@ class CamembertForMaskedLM(RobertaForMaskedLM):
|
||||
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
|
||||
@add_start_docstrings("""CamemBERT Model transformer with a sequence classification/regression head on top (a linear layer
|
||||
@add_start_docstrings(
|
||||
"""CamemBERT Model transformer with a sequence classification/regression head on top (a linear layer
|
||||
on top of the pooled output) e.g. for GLUE tasks. """,
|
||||
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
|
||||
CAMEMBERT_START_DOCSTRING,
|
||||
CAMEMBERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class CamembertForSequenceClassification(RobertaForSequenceClassification):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
@@ -223,9 +238,12 @@ class CamembertForSequenceClassification(RobertaForSequenceClassification):
|
||||
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
|
||||
@add_start_docstrings("""CamemBERT Model with a multiple choice classification head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""CamemBERT Model with a multiple choice classification head on top (a linear layer on top of
|
||||
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
||||
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
|
||||
CAMEMBERT_START_DOCSTRING,
|
||||
CAMEMBERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class CamembertForMultipleChoice(RobertaForMultipleChoice):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
@@ -257,9 +275,12 @@ class CamembertForMultipleChoice(RobertaForMultipleChoice):
|
||||
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
|
||||
@add_start_docstrings("""CamemBERT Model with a token classification head on top (a linear layer on top of
|
||||
@add_start_docstrings(
|
||||
"""CamemBERT Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
|
||||
CAMEMBERT_START_DOCSTRING,
|
||||
CAMEMBERT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class CamembertForTokenClassification(RobertaForTokenClassification):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user