Reformat source code with black.
This is the result of:
$ black --line-length 119 examples templates transformers utils hubconf.py setup.py
There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.
This is also Thomas' preference, because it allows for explicit variable
names, to make the code easier to understand.
This commit is contained in:
@@ -247,7 +247,8 @@ the wall, slowly on into the Social Predestination Room.
|
|||||||
as they entered."""
|
as they entered."""
|
||||||
|
|
||||||
|
|
||||||
def create_setup_and_compute(model_names: List[str],
|
def create_setup_and_compute(
|
||||||
|
model_names: List[str],
|
||||||
gpu: bool = True,
|
gpu: bool = True,
|
||||||
tensorflow: bool = False,
|
tensorflow: bool = False,
|
||||||
average_over: int = 3,
|
average_over: int = 3,
|
||||||
@@ -256,7 +257,8 @@ def create_setup_and_compute(model_names: List[str],
|
|||||||
amp: bool = False,
|
amp: bool = False,
|
||||||
fp16: bool = False,
|
fp16: bool = False,
|
||||||
save_to_csv: bool = False,
|
save_to_csv: bool = False,
|
||||||
csv_filename: str = f"results_{round(time())}.csv"):
|
csv_filename: str = f"results_{round(time())}.csv",
|
||||||
|
):
|
||||||
if xla:
|
if xla:
|
||||||
tf.config.optimizer.set_jit(True)
|
tf.config.optimizer.set_jit(True)
|
||||||
if amp:
|
if amp:
|
||||||
@@ -266,7 +268,7 @@ def create_setup_and_compute(model_names: List[str],
|
|||||||
dictionary = {model_name: {} for model_name in model_names}
|
dictionary = {model_name: {} for model_name in model_names}
|
||||||
results = _compute_tensorflow(model_names, dictionary, average_over, amp)
|
results = _compute_tensorflow(model_names, dictionary, average_over, amp)
|
||||||
else:
|
else:
|
||||||
device = 'cuda' if (gpu and torch.cuda.is_available()) else 'cpu'
|
device = "cuda" if (gpu and torch.cuda.is_available()) else "cpu"
|
||||||
dictionary = {model_name: {} for model_name in model_names}
|
dictionary = {model_name: {} for model_name in model_names}
|
||||||
results = _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16)
|
results = _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16)
|
||||||
|
|
||||||
@@ -276,22 +278,40 @@ def create_setup_and_compute(model_names: List[str],
|
|||||||
for batch_size in results[model_name]["bs"]:
|
for batch_size in results[model_name]["bs"]:
|
||||||
print("\t\t" + f"===== BATCH SIZE: {batch_size} =====")
|
print("\t\t" + f"===== BATCH SIZE: {batch_size} =====")
|
||||||
for slice_size in results[model_name]["ss"]:
|
for slice_size in results[model_name]["ss"]:
|
||||||
result = results[model_name]['results'][batch_size][slice_size]
|
result = results[model_name]["results"][batch_size][slice_size]
|
||||||
if isinstance(result, str):
|
if isinstance(result, str):
|
||||||
print(f"\t\t{model_name}/{batch_size}/{slice_size}: "
|
print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result}")
|
||||||
f"{result}")
|
|
||||||
else:
|
else:
|
||||||
print(f"\t\t{model_name}/{batch_size}/{slice_size}: "
|
print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{(round(1000 * result) / 1000)}" f"s")
|
||||||
f"{(round(1000 * result) / 1000)}"
|
|
||||||
f"s")
|
|
||||||
|
|
||||||
if save_to_csv:
|
if save_to_csv:
|
||||||
with open(csv_filename, mode='w') as csv_file:
|
with open(csv_filename, mode="w") as csv_file:
|
||||||
fieldnames = ['model',
|
fieldnames = [
|
||||||
'1x8', '1x64', '1x128', '1x256', '1x512', '1x1024',
|
"model",
|
||||||
'2x8', '2x64', '2x128', '2x256', '2x512', '2x1024',
|
"1x8",
|
||||||
'4x8', '4x64', '4x128', '4x256', '4x512', '4x1024',
|
"1x64",
|
||||||
'8x8', '8x64', '8x128', '8x256', '8x512', '8x1024',
|
"1x128",
|
||||||
|
"1x256",
|
||||||
|
"1x512",
|
||||||
|
"1x1024",
|
||||||
|
"2x8",
|
||||||
|
"2x64",
|
||||||
|
"2x128",
|
||||||
|
"2x256",
|
||||||
|
"2x512",
|
||||||
|
"2x1024",
|
||||||
|
"4x8",
|
||||||
|
"4x64",
|
||||||
|
"4x128",
|
||||||
|
"4x256",
|
||||||
|
"4x512",
|
||||||
|
"4x1024",
|
||||||
|
"8x8",
|
||||||
|
"8x64",
|
||||||
|
"8x128",
|
||||||
|
"8x256",
|
||||||
|
"8x512",
|
||||||
|
"8x1024",
|
||||||
]
|
]
|
||||||
|
|
||||||
writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
|
writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
|
||||||
@@ -299,11 +319,11 @@ def create_setup_and_compute(model_names: List[str],
|
|||||||
|
|
||||||
for model_name in model_names:
|
for model_name in model_names:
|
||||||
model_results = {
|
model_results = {
|
||||||
f'{bs}x{ss}': results[model_name]['results'][bs][ss]
|
f"{bs}x{ss}": results[model_name]["results"][bs][ss]
|
||||||
for bs in results[model_name]["results"]
|
for bs in results[model_name]["results"]
|
||||||
for ss in results[model_name]['results'][bs]
|
for ss in results[model_name]["results"][bs]
|
||||||
}
|
}
|
||||||
writer.writerow({'model': model_name, **model_results})
|
writer.writerow({"model": model_name, **model_results})
|
||||||
|
|
||||||
|
|
||||||
def _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16):
|
def _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16):
|
||||||
@@ -379,7 +399,9 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
|
|||||||
if max_input_size is not None and slice_size > max_input_size:
|
if max_input_size is not None and slice_size > max_input_size:
|
||||||
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
|
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
|
||||||
else:
|
else:
|
||||||
sequence = tf.stack([tf.squeeze(tf.constant(tokenized_sequence[:slice_size])[None, :])] * batch_size)
|
sequence = tf.stack(
|
||||||
|
[tf.squeeze(tf.constant(tokenized_sequence[:slice_size])[None, :])] * batch_size
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print("Going through model with sequence of shape", sequence.shape)
|
print("Going through model with sequence of shape", sequence.shape)
|
||||||
@@ -399,33 +421,64 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
|
|||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument("--models", required=False, type=str, default='all', help="Model checkpoints to be provided "
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
required=False,
|
||||||
|
type=str,
|
||||||
|
default="all",
|
||||||
|
help="Model checkpoints to be provided "
|
||||||
"to the AutoModel classes. Leave "
|
"to the AutoModel classes. Leave "
|
||||||
"blank to benchmark the base version "
|
"blank to benchmark the base version "
|
||||||
"of all available model "
|
"of all available model "
|
||||||
"architectures.")
|
"architectures.",
|
||||||
parser.add_argument("--torch", required=False, action="store_true", help="Benchmark the Pytorch version of the "
|
)
|
||||||
"models")
|
parser.add_argument(
|
||||||
parser.add_argument("--torch_cuda", required=False, action="store_true", help="Pytorch only: run on available "
|
"--torch", required=False, action="store_true", help="Benchmark the Pytorch version of the " "models"
|
||||||
"cuda devices")
|
)
|
||||||
parser.add_argument("--torchscript", required=False, action="store_true", help="Pytorch only: trace the models "
|
parser.add_argument(
|
||||||
"using torchscript")
|
"--torch_cuda", required=False, action="store_true", help="Pytorch only: run on available " "cuda devices"
|
||||||
parser.add_argument("--tensorflow", required=False, action="store_true", help="Benchmark the TensorFlow version "
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--torchscript",
|
||||||
|
required=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Pytorch only: trace the models " "using torchscript",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tensorflow",
|
||||||
|
required=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Benchmark the TensorFlow version "
|
||||||
"of the models. Will run on GPU if "
|
"of the models. Will run on GPU if "
|
||||||
"the correct dependencies are "
|
"the correct dependencies are "
|
||||||
"installed")
|
"installed",
|
||||||
|
)
|
||||||
parser.add_argument("--xla", required=False, action="store_true", help="TensorFlow only: use XLA acceleration.")
|
parser.add_argument("--xla", required=False, action="store_true", help="TensorFlow only: use XLA acceleration.")
|
||||||
parser.add_argument("--amp", required=False, action="store_true", help="TensorFlow only: use automatic mixed precision acceleration.")
|
parser.add_argument(
|
||||||
parser.add_argument("--fp16", required=False, action="store_true", help="PyTorch only: use FP16 to accelerate inference.")
|
"--amp",
|
||||||
parser.add_argument("--keras_predict", required=False, action="store_true", help="Whether to use model.predict "
|
required=False,
|
||||||
"instead of model() to do a "
|
action="store_true",
|
||||||
"forward pass.")
|
help="TensorFlow only: use automatic mixed precision acceleration.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16", required=False, action="store_true", help="PyTorch only: use FP16 to accelerate inference."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--keras_predict",
|
||||||
|
required=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to use model.predict " "instead of model() to do a " "forward pass.",
|
||||||
|
)
|
||||||
parser.add_argument("--save_to_csv", required=False, action="store_true", help="Save to a CSV file.")
|
parser.add_argument("--save_to_csv", required=False, action="store_true", help="Save to a CSV file.")
|
||||||
parser.add_argument("--csv_filename", required=False, default=None, help="CSV filename used if saving results to csv.")
|
parser.add_argument(
|
||||||
parser.add_argument("--average_over", required=False, default=30, type=int, help="Times an experiment will be run.")
|
"--csv_filename", required=False, default=None, help="CSV filename used if saving results to csv."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--average_over", required=False, default=30, type=int, help="Times an experiment will be run."
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.models == 'all':
|
if args.models == "all":
|
||||||
args.models = [
|
args.models = [
|
||||||
"gpt2",
|
"gpt2",
|
||||||
"bert-base-cased",
|
"bert-base-cased",
|
||||||
@@ -436,7 +489,7 @@ def main():
|
|||||||
"distilbert-base-uncased",
|
"distilbert-base-uncased",
|
||||||
"distilgpt2",
|
"distilgpt2",
|
||||||
"roberta-base",
|
"roberta-base",
|
||||||
"ctrl"
|
"ctrl",
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
args.models = args.models.split()
|
args.models = args.models.split()
|
||||||
@@ -453,7 +506,7 @@ def main():
|
|||||||
fp16=args.fp16,
|
fp16=args.fp16,
|
||||||
save_to_csv=args.save_to_csv,
|
save_to_csv=args.save_to_csv,
|
||||||
csv_filename=args.csv_filename,
|
csv_filename=args.csv_filename,
|
||||||
average_over=args.average_over
|
average_over=args.average_over,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
|
raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
|
||||||
@@ -467,11 +520,11 @@ def main():
|
|||||||
amp=args.amp,
|
amp=args.amp,
|
||||||
save_to_csv=args.save_to_csv,
|
save_to_csv=args.save_to_csv,
|
||||||
csv_filename=args.csv_filename,
|
csv_filename=args.csv_filename,
|
||||||
average_over=args.average_over
|
average_over=args.average_over,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")
|
raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
@@ -10,38 +10,37 @@ from transformers.modeling_camembert import CamembertForMaskedLM
|
|||||||
|
|
||||||
def fill_mask(masked_input, model, tokenizer, topk=5):
|
def fill_mask(masked_input, model, tokenizer, topk=5):
|
||||||
# Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py
|
# Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py
|
||||||
assert masked_input.count('<mask>') == 1
|
assert masked_input.count("<mask>") == 1
|
||||||
input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||||
logits = model(input_ids)[0] # The last hidden-state is the first element of the output tuple
|
logits = model(input_ids)[0] # The last hidden-state is the first element of the output tuple
|
||||||
masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
|
masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
|
||||||
logits = logits[0, masked_index, :]
|
logits = logits[0, masked_index, :]
|
||||||
prob = logits.softmax(dim=0)
|
prob = logits.softmax(dim=0)
|
||||||
values, indices = prob.topk(k=topk, dim=0)
|
values, indices = prob.topk(k=topk, dim=0)
|
||||||
topk_predicted_token_bpe = ' '.join([tokenizer.convert_ids_to_tokens(indices[i].item())
|
topk_predicted_token_bpe = " ".join(
|
||||||
for i in range(len(indices))])
|
[tokenizer.convert_ids_to_tokens(indices[i].item()) for i in range(len(indices))]
|
||||||
|
)
|
||||||
masked_token = tokenizer.mask_token
|
masked_token = tokenizer.mask_token
|
||||||
topk_filled_outputs = []
|
topk_filled_outputs = []
|
||||||
for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(' ')):
|
for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(" ")):
|
||||||
predicted_token = predicted_token_bpe.replace('\u2581', ' ')
|
predicted_token = predicted_token_bpe.replace("\u2581", " ")
|
||||||
if " {0}".format(masked_token) in masked_input:
|
if " {0}".format(masked_token) in masked_input:
|
||||||
topk_filled_outputs.append((
|
topk_filled_outputs.append(
|
||||||
masked_input.replace(
|
(
|
||||||
' {0}'.format(masked_token), predicted_token
|
masked_input.replace(" {0}".format(masked_token), predicted_token),
|
||||||
),
|
|
||||||
values[index].item(),
|
values[index].item(),
|
||||||
predicted_token,
|
predicted_token,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
topk_filled_outputs.append((
|
topk_filled_outputs.append(
|
||||||
masked_input.replace(masked_token, predicted_token),
|
(masked_input.replace(masked_token, predicted_token), values[index].item(), predicted_token,)
|
||||||
values[index].item(),
|
)
|
||||||
predicted_token,
|
|
||||||
))
|
|
||||||
return topk_filled_outputs
|
return topk_filled_outputs
|
||||||
|
|
||||||
|
|
||||||
tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
|
tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
|
||||||
model = CamembertForMaskedLM.from_pretrained('camembert-base')
|
model = CamembertForMaskedLM.from_pretrained("camembert-base")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
masked_input = "Le camembert est <mask> :)"
|
masked_input = "Le camembert est <mask> :)"
|
||||||
|
|||||||
@@ -36,34 +36,42 @@ from tqdm import tqdm, trange
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
TensorDataset)
|
|
||||||
|
|
||||||
from transformers import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
|
from transformers import (
|
||||||
AdamW, cached_path, WEIGHTS_NAME, CONFIG_NAME,
|
OpenAIGPTDoubleHeadsModel,
|
||||||
get_linear_schedule_with_warmup)
|
OpenAIGPTTokenizer,
|
||||||
|
AdamW,
|
||||||
|
cached_path,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
CONFIG_NAME,
|
||||||
|
get_linear_schedule_with_warmup,
|
||||||
|
)
|
||||||
|
|
||||||
ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
|
ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||||
level = logging.INFO)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def accuracy(out, labels):
|
def accuracy(out, labels):
|
||||||
outputs = np.argmax(out, axis=1)
|
outputs = np.argmax(out, axis=1)
|
||||||
return np.sum(outputs == labels)
|
return np.sum(outputs == labels)
|
||||||
|
|
||||||
|
|
||||||
def load_rocstories_dataset(dataset_path):
|
def load_rocstories_dataset(dataset_path):
|
||||||
""" Output a list of tuples(story, 1st continuation, 2nd continuation, label) """
|
""" Output a list of tuples(story, 1st continuation, 2nd continuation, label) """
|
||||||
with open(dataset_path, encoding='utf_8') as f:
|
with open(dataset_path, encoding="utf_8") as f:
|
||||||
f = csv.reader(f)
|
f = csv.reader(f)
|
||||||
output = []
|
output = []
|
||||||
next(f) # skip the first line
|
next(f) # skip the first line
|
||||||
for line in tqdm(f):
|
for line in tqdm(f):
|
||||||
output.append((' '.join(line[1:5]), line[5], line[6], int(line[-1])-1))
|
output.append((" ".join(line[1:5]), line[5], line[6], int(line[-1]) - 1))
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, delimiter_token, clf_token):
|
def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, delimiter_token, clf_token):
|
||||||
""" Pre-process datasets containing lists of tuples(story, 1st continuation, 2nd continuation, label)
|
""" Pre-process datasets containing lists of tuples(story, 1st continuation, 2nd continuation, label)
|
||||||
|
|
||||||
@@ -91,45 +99,57 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
|
|||||||
tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
|
tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
|
||||||
return tensor_datasets
|
return tensor_datasets
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model_name', type=str, default='openai-gpt',
|
parser.add_argument("--model_name", type=str, default="openai-gpt", help="pretrained model name")
|
||||||
help='pretrained model name')
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
|
parser.add_argument(
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
"--output_dir",
|
||||||
help="The output directory where the model predictions and checkpoints will be written.")
|
default=None,
|
||||||
parser.add_argument('--train_dataset', type=str, default='')
|
type=str,
|
||||||
parser.add_argument('--eval_dataset', type=str, default='')
|
required=True,
|
||||||
parser.add_argument('--seed', type=int, default=42)
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
parser.add_argument('--num_train_epochs', type=int, default=3)
|
)
|
||||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
parser.add_argument("--train_dataset", type=str, default="")
|
||||||
parser.add_argument('--eval_batch_size', type=int, default=16)
|
parser.add_argument("--eval_dataset", type=str, default="")
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
help="Epsilon for Adam optimizer.")
|
parser.add_argument("--num_train_epochs", type=int, default=3)
|
||||||
parser.add_argument('--max_grad_norm', type=int, default=1)
|
parser.add_argument("--train_batch_size", type=int, default=8)
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
parser.add_argument("--eval_batch_size", type=int, default=16)
|
||||||
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
|
parser.add_argument("--max_grad_norm", type=int, default=1)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_steps",
|
||||||
|
default=-1,
|
||||||
|
type=int,
|
||||||
help="If > 0: set total number of training \
|
help="If > 0: set total number of training \
|
||||||
steps to perform. Override num_train_epochs.")
|
steps to perform. Override num_train_epochs.",
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gradient_accumulation_steps",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
help="Number of updates steps to accumulate before\
|
help="Number of updates steps to accumulate before\
|
||||||
performing a backward/update pass.")
|
performing a backward/update pass.",
|
||||||
parser.add_argument('--learning_rate', type=float, default=6.25e-5)
|
)
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
parser.add_argument("--learning_rate", type=float, default=6.25e-5)
|
||||||
help="Linear warmup over warmup_steps.")
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
|
parser.add_argument("--lr_schedule", type=str, default="warmup_linear")
|
||||||
parser.add_argument('--weight_decay', type=float, default=0.01)
|
parser.add_argument("--weight_decay", type=float, default=0.01)
|
||||||
parser.add_argument('--lm_coef', type=float, default=0.9)
|
parser.add_argument("--lm_coef", type=float, default=0.9)
|
||||||
parser.add_argument('--n_valid', type=int, default=374)
|
parser.add_argument("--n_valid", type=int, default=374)
|
||||||
|
|
||||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
@@ -152,7 +172,7 @@ def main():
|
|||||||
# Load tokenizer and model
|
# Load tokenizer and model
|
||||||
# This loading functions also add new tokens and embeddings called `special tokens`
|
# This loading functions also add new tokens and embeddings called `special tokens`
|
||||||
# These new embeddings will be fine-tuned on the RocStories dataset
|
# These new embeddings will be fine-tuned on the RocStories dataset
|
||||||
special_tokens = ['_start_', '_delimiter_', '_classify_']
|
special_tokens = ["_start_", "_delimiter_", "_classify_"]
|
||||||
tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name)
|
tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name)
|
||||||
tokenizer.add_tokens(special_tokens)
|
tokenizer.add_tokens(special_tokens)
|
||||||
special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
|
special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
|
||||||
@@ -163,6 +183,7 @@ def main():
|
|||||||
# Load and encode the datasets
|
# Load and encode the datasets
|
||||||
if not args.train_dataset and not args.eval_dataset:
|
if not args.train_dataset and not args.eval_dataset:
|
||||||
roc_stories = cached_path(ROCSTORIES_URL)
|
roc_stories = cached_path(ROCSTORIES_URL)
|
||||||
|
|
||||||
def tokenize_and_encode(obj):
|
def tokenize_and_encode(obj):
|
||||||
""" Tokenize and encode a nested object """
|
""" Tokenize and encode a nested object """
|
||||||
if isinstance(obj, str):
|
if isinstance(obj, str):
|
||||||
@@ -170,6 +191,7 @@ def main():
|
|||||||
elif isinstance(obj, int):
|
elif isinstance(obj, int):
|
||||||
return obj
|
return obj
|
||||||
return list(tokenize_and_encode(o) for o in obj)
|
return list(tokenize_and_encode(o) for o in obj)
|
||||||
|
|
||||||
logger.info("Encoding dataset...")
|
logger.info("Encoding dataset...")
|
||||||
train_dataset = load_rocstories_dataset(args.train_dataset)
|
train_dataset = load_rocstories_dataset(args.train_dataset)
|
||||||
eval_dataset = load_rocstories_dataset(args.eval_dataset)
|
eval_dataset = load_rocstories_dataset(args.eval_dataset)
|
||||||
@@ -178,8 +200,11 @@ def main():
|
|||||||
|
|
||||||
# Compute the max input length for the Transformer
|
# Compute the max input length for the Transformer
|
||||||
max_length = model.config.n_positions // 2 - 2
|
max_length = model.config.n_positions // 2 - 2
|
||||||
input_length = max(len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3 \
|
input_length = max(
|
||||||
for dataset in encoded_datasets for story, cont1, cont2, _ in dataset)
|
len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3
|
||||||
|
for dataset in encoded_datasets
|
||||||
|
for story, cont1, cont2, _ in dataset
|
||||||
|
)
|
||||||
input_length = min(input_length, model.config.n_positions) # Max size of input for the pre-trained model
|
input_length = min(input_length, model.config.n_positions) # Max size of input for the pre-trained model
|
||||||
|
|
||||||
# Prepare inputs tensors and dataloaders
|
# Prepare inputs tensors and dataloaders
|
||||||
@@ -198,20 +223,23 @@ def main():
|
|||||||
if args.do_train:
|
if args.do_train:
|
||||||
if args.max_steps > 0:
|
if args.max_steps > 0:
|
||||||
t_total = args.max_steps
|
t_total = args.max_steps
|
||||||
args.num_train_epochs = args.max_steps //\
|
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||||
(len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
|
||||||
else:
|
else:
|
||||||
t_total = len(train_dataloader)\
|
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
// args.gradient_accumulation_steps * args.num_train_epochs
|
|
||||||
|
|
||||||
param_optimizer = list(model.named_parameters())
|
param_optimizer = list(model.named_parameters())
|
||||||
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
{
|
||||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
},
|
||||||
|
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||||
]
|
]
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||||
|
)
|
||||||
|
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
|
nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
|
||||||
@@ -230,14 +258,16 @@ def main():
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
exp_average_loss = loss.item() if exp_average_loss is None else 0.7*exp_average_loss+0.3*loss.item()
|
exp_average_loss = (
|
||||||
|
loss.item() if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item()
|
||||||
|
)
|
||||||
nb_tr_steps += 1
|
nb_tr_steps += 1
|
||||||
tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, scheduler.get_lr()[0])
|
tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, scheduler.get_lr()[0])
|
||||||
|
|
||||||
# Save a trained model
|
# Save a trained model
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
# Save a trained model, configuration and tokenizer
|
# Save a trained model, configuration and tokenizer
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model itself
|
model_to_save = model.module if hasattr(model, "module") else model # Only save the model itself
|
||||||
|
|
||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||||
@@ -260,10 +290,12 @@ def main():
|
|||||||
batch = tuple(t.to(device) for t in batch)
|
batch = tuple(t.to(device) for t in batch)
|
||||||
input_ids, mc_token_ids, lm_labels, mc_labels = batch
|
input_ids, mc_token_ids, lm_labels, mc_labels = batch
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_, mc_loss, _, mc_logits = model(input_ids, mc_token_ids=mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels)
|
_, mc_loss, _, mc_logits = model(
|
||||||
|
input_ids, mc_token_ids=mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels
|
||||||
|
)
|
||||||
|
|
||||||
mc_logits = mc_logits.detach().cpu().numpy()
|
mc_logits = mc_logits.detach().cpu().numpy()
|
||||||
mc_labels = mc_labels.to('cpu').numpy()
|
mc_labels = mc_labels.to("cpu").numpy()
|
||||||
tmp_eval_accuracy = accuracy(mc_logits, mc_labels)
|
tmp_eval_accuracy = accuracy(mc_logits, mc_labels)
|
||||||
|
|
||||||
eval_loss += mc_loss.mean().item()
|
eval_loss += mc_loss.mean().item()
|
||||||
@@ -275,9 +307,7 @@ def main():
|
|||||||
eval_loss = eval_loss / nb_eval_steps
|
eval_loss = eval_loss / nb_eval_steps
|
||||||
eval_accuracy = eval_accuracy / nb_eval_examples
|
eval_accuracy = eval_accuracy / nb_eval_examples
|
||||||
train_loss = tr_loss / nb_tr_steps if args.do_train else None
|
train_loss = tr_loss / nb_tr_steps if args.do_train else None
|
||||||
result = {'eval_loss': eval_loss,
|
result = {"eval_loss": eval_loss, "eval_accuracy": eval_accuracy, "train_loss": train_loss}
|
||||||
'eval_accuracy': eval_accuracy,
|
|
||||||
'train_loss': train_loss}
|
|
||||||
|
|
||||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||||
with open(output_eval_file, "w") as writer:
|
with open(output_eval_file, "w") as writer:
|
||||||
@@ -286,5 +316,6 @@ def main():
|
|||||||
logger.info(" %s = %s", key, str(result[key]))
|
logger.info(" %s = %s", key, str(result[key]))
|
||||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -28,8 +28,7 @@ import glob
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
TensorDataset)
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -39,31 +38,23 @@ except:
|
|||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
from transformers import WEIGHTS_NAME, BertConfig, BertForMultipleChoice, BertTokenizer
|
||||||
BertForMultipleChoice, BertTokenizer)
|
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in [BertConfig]), ())
|
||||||
for conf in [BertConfig]), ())
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
|
"bert": (BertConfig, BertForMultipleChoice, BertTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class SwagExample(object):
|
class SwagExample(object):
|
||||||
"""A single training/test example for the SWAG dataset."""
|
"""A single training/test example for the SWAG dataset."""
|
||||||
def __init__(self,
|
|
||||||
swag_id,
|
def __init__(self, swag_id, context_sentence, start_ending, ending_0, ending_1, ending_2, ending_3, label=None):
|
||||||
context_sentence,
|
|
||||||
start_ending,
|
|
||||||
ending_0,
|
|
||||||
ending_1,
|
|
||||||
ending_2,
|
|
||||||
ending_3,
|
|
||||||
label = None):
|
|
||||||
self.swag_id = swag_id
|
self.swag_id = swag_id
|
||||||
self.context_sentence = context_sentence
|
self.context_sentence = context_sentence
|
||||||
self.start_ending = start_ending
|
self.start_ending = start_ending
|
||||||
@@ -94,37 +85,28 @@ class SwagExample(object):
|
|||||||
|
|
||||||
return ", ".join(l)
|
return ", ".join(l)
|
||||||
|
|
||||||
class InputFeatures(object):
|
|
||||||
def __init__(self,
|
|
||||||
example_id,
|
|
||||||
choices_features,
|
|
||||||
label
|
|
||||||
|
|
||||||
):
|
class InputFeatures(object):
|
||||||
|
def __init__(self, example_id, choices_features, label):
|
||||||
self.example_id = example_id
|
self.example_id = example_id
|
||||||
self.choices_features = [
|
self.choices_features = [
|
||||||
{
|
{"input_ids": input_ids, "input_mask": input_mask, "segment_ids": segment_ids}
|
||||||
'input_ids': input_ids,
|
|
||||||
'input_mask': input_mask,
|
|
||||||
'segment_ids': segment_ids
|
|
||||||
}
|
|
||||||
for _, input_ids, input_mask, segment_ids in choices_features
|
for _, input_ids, input_mask, segment_ids in choices_features
|
||||||
]
|
]
|
||||||
self.label = label
|
self.label = label
|
||||||
|
|
||||||
|
|
||||||
def read_swag_examples(input_file, is_training=True):
|
def read_swag_examples(input_file, is_training=True):
|
||||||
with open(input_file, 'r', encoding='utf-8') as f:
|
with open(input_file, "r", encoding="utf-8") as f:
|
||||||
reader = csv.reader(f)
|
reader = csv.reader(f)
|
||||||
lines = []
|
lines = []
|
||||||
for line in reader:
|
for line in reader:
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
line = list(unicode(cell, 'utf-8') for cell in line)
|
line = list(unicode(cell, "utf-8") for cell in line)
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
|
|
||||||
if is_training and lines[0][-1] != 'label':
|
if is_training and lines[0][-1] != "label":
|
||||||
raise ValueError(
|
raise ValueError("For training, the input file must contain a label column.")
|
||||||
"For training, the input file must contain a label column."
|
|
||||||
)
|
|
||||||
|
|
||||||
examples = [
|
examples = [
|
||||||
SwagExample(
|
SwagExample(
|
||||||
@@ -137,14 +119,15 @@ def read_swag_examples(input_file, is_training=True):
|
|||||||
ending_1=line[8],
|
ending_1=line[8],
|
||||||
ending_2=line[9],
|
ending_2=line[9],
|
||||||
ending_3=line[10],
|
ending_3=line[10],
|
||||||
label = int(line[11]) if is_training else None
|
label=int(line[11]) if is_training else None,
|
||||||
) for line in lines[1:] # we skip the line with the column names
|
)
|
||||||
|
for line in lines[1:] # we skip the line with the column names
|
||||||
]
|
]
|
||||||
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|
||||||
is_training):
|
def convert_examples_to_features(examples, tokenizer, max_seq_length, is_training):
|
||||||
"""Loads a data file into a list of `InputBatch`s."""
|
"""Loads a data file into a list of `InputBatch`s."""
|
||||||
|
|
||||||
# Swag is a multiple choice task. To perform this task using Bert,
|
# Swag is a multiple choice task. To perform this task using Bert,
|
||||||
@@ -204,23 +187,18 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
logger.info("swag_id: {}".format(example.swag_id))
|
logger.info("swag_id: {}".format(example.swag_id))
|
||||||
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
|
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
|
||||||
logger.info("choice: {}".format(choice_idx))
|
logger.info("choice: {}".format(choice_idx))
|
||||||
logger.info("tokens: {}".format(' '.join(tokens)))
|
logger.info("tokens: {}".format(" ".join(tokens)))
|
||||||
logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
|
logger.info("input_ids: {}".format(" ".join(map(str, input_ids))))
|
||||||
logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
|
logger.info("input_mask: {}".format(" ".join(map(str, input_mask))))
|
||||||
logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
|
logger.info("segment_ids: {}".format(" ".join(map(str, segment_ids))))
|
||||||
if is_training:
|
if is_training:
|
||||||
logger.info("label: {}".format(label))
|
logger.info("label: {}".format(label))
|
||||||
|
|
||||||
features.append(
|
features.append(InputFeatures(example_id=example.swag_id, choices_features=choices_features, label=label))
|
||||||
InputFeatures(
|
|
||||||
example_id = example.swag_id,
|
|
||||||
choices_features = choices_features,
|
|
||||||
label = label
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||||
"""Truncates a sequence pair in place to the maximum length."""
|
"""Truncates a sequence pair in place to the maximum length."""
|
||||||
|
|
||||||
@@ -237,18 +215,14 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
|||||||
else:
|
else:
|
||||||
tokens_b.pop()
|
tokens_b.pop()
|
||||||
|
|
||||||
|
|
||||||
def accuracy(out, labels):
|
def accuracy(out, labels):
|
||||||
outputs = np.argmax(out, axis=1)
|
outputs = np.argmax(out, axis=1)
|
||||||
return np.sum(outputs == labels)
|
return np.sum(outputs == labels)
|
||||||
|
|
||||||
|
|
||||||
def select_field(features, field):
|
def select_field(features, field):
|
||||||
return [
|
return [[choice[field] for choice in feature.choices_features] for feature in features]
|
||||||
[
|
|
||||||
choice[field]
|
|
||||||
for choice in feature.choices_features
|
|
||||||
]
|
|
||||||
for feature in features
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(args):
|
def set_seed(args):
|
||||||
@@ -258,24 +232,28 @@ def set_seed(args):
|
|||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
|
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
|
||||||
if args.local_rank not in [-1, 0]:
|
if args.local_rank not in [-1, 0]:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
|
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
input_file = args.predict_file if evaluate else args.train_file
|
input_file = args.predict_file if evaluate else args.train_file
|
||||||
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
|
cached_features_file = os.path.join(
|
||||||
'dev' if evaluate else 'train',
|
os.path.dirname(input_file),
|
||||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
"cached_{}_{}_{}".format(
|
||||||
str(args.max_seq_length)))
|
"dev" if evaluate else "train",
|
||||||
|
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||||
|
str(args.max_seq_length),
|
||||||
|
),
|
||||||
|
)
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
features = torch.load(cached_features_file)
|
features = torch.load(cached_features_file)
|
||||||
else:
|
else:
|
||||||
logger.info("Creating features from dataset file at %s", input_file)
|
logger.info("Creating features from dataset file at %s", input_file)
|
||||||
examples = read_swag_examples(input_file)
|
examples = read_swag_examples(input_file)
|
||||||
features = convert_examples_to_features(
|
features = convert_examples_to_features(examples, tokenizer, args.max_seq_length, not evaluate)
|
||||||
examples, tokenizer, args.max_seq_length, not evaluate)
|
|
||||||
|
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
@@ -285,21 +263,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
|
|
||||||
# Convert to Tensors and build dataset
|
# Convert to Tensors and build dataset
|
||||||
all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long)
|
all_input_ids = torch.tensor(select_field(features, "input_ids"), dtype=torch.long)
|
||||||
all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long)
|
all_input_mask = torch.tensor(select_field(features, "input_mask"), dtype=torch.long)
|
||||||
all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long)
|
all_segment_ids = torch.tensor(select_field(features, "segment_ids"), dtype=torch.long)
|
||||||
all_label = torch.tensor([f.label for f in features], dtype=torch.long)
|
all_label = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||||
|
|
||||||
if evaluate:
|
if evaluate:
|
||||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
|
||||||
all_label)
|
|
||||||
else:
|
else:
|
||||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
|
||||||
all_label)
|
|
||||||
|
|
||||||
if output_examples:
|
if output_examples:
|
||||||
return dataset, examples, features
|
return dataset, examples, features
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def train(args, train_dataset, model, tokenizer):
|
def train(args, train_dataset, model, tokenizer):
|
||||||
""" Train the model """
|
""" Train the model """
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
@@ -316,13 +294,18 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
|
|
||||||
# Prepare optimizer and schedule (linear warmup and decay)
|
# Prepare optimizer and schedule (linear warmup and decay)
|
||||||
no_decay = ['bias', 'LayerNorm.weight']
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
{
|
||||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
},
|
||||||
|
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||||
]
|
]
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||||
|
)
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
@@ -336,17 +319,21 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
output_device=args.local_rank,
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
find_unused_parameters=True)
|
)
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_dataset))
|
logger.info(" Num examples = %d", len(train_dataset))
|
||||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
logger.info(
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
|
args.train_batch_size
|
||||||
|
* args.gradient_accumulation_steps
|
||||||
|
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||||
|
)
|
||||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
@@ -360,11 +347,13 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
for step, batch in enumerate(epoch_iterator):
|
for step, batch in enumerate(epoch_iterator):
|
||||||
model.train()
|
model.train()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {
|
||||||
'attention_mask': batch[1],
|
"input_ids": batch[0],
|
||||||
|
"attention_mask": batch[1],
|
||||||
#'token_type_ids': None if args.model_type == 'xlm' else batch[2],
|
#'token_type_ids': None if args.model_type == 'xlm' else batch[2],
|
||||||
'token_type_ids': batch[2],
|
"token_type_ids": batch[2],
|
||||||
'labels': batch[3]}
|
"labels": batch[3],
|
||||||
|
}
|
||||||
# if args.model_type in ['xlnet', 'xlm']:
|
# if args.model_type in ['xlnet', 'xlm']:
|
||||||
# inputs.update({'cls_index': batch[5],
|
# inputs.update({'cls_index': batch[5],
|
||||||
# 'p_mask': batch[6]})
|
# 'p_mask': batch[6]})
|
||||||
@@ -393,23 +382,27 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
# Log metrics
|
# Log metrics
|
||||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
if (
|
||||||
|
args.local_rank == -1 and args.evaluate_during_training
|
||||||
|
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||||
results = evaluate(args, model, tokenizer)
|
results = evaluate(args, model, tokenizer)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||||
logging_loss = tr_loss
|
logging_loss = tr_loss
|
||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(output_dir)
|
model_to_save.save_pretrained(output_dir)
|
||||||
tokenizer.save_vocabulary(output_dir)
|
tokenizer.save_vocabulary(output_dir)
|
||||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
@@ -424,6 +417,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
return global_step, tr_loss / global_step
|
return global_step, tr_loss / global_step
|
||||||
|
|
||||||
|
|
||||||
def evaluate(args, model, tokenizer, prefix=""):
|
def evaluate(args, model, tokenizer, prefix=""):
|
||||||
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
|
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
|
||||||
|
|
||||||
@@ -440,7 +434,6 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
logger.info(" Num examples = %d", len(dataset))
|
logger.info(" Num examples = %d", len(dataset))
|
||||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||||
|
|
||||||
|
|
||||||
eval_loss, eval_accuracy = 0, 0
|
eval_loss, eval_accuracy = 0, 0
|
||||||
nb_eval_steps, nb_eval_examples = 0, 0
|
nb_eval_steps, nb_eval_examples = 0, 0
|
||||||
|
|
||||||
@@ -448,11 +441,13 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
model.eval()
|
model.eval()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {
|
||||||
'attention_mask': batch[1],
|
"input_ids": batch[0],
|
||||||
|
"attention_mask": batch[1],
|
||||||
# 'token_type_ids': None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
# 'token_type_ids': None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
||||||
'token_type_ids': batch[2],
|
"token_type_ids": batch[2],
|
||||||
'labels': batch[3]}
|
"labels": batch[3],
|
||||||
|
}
|
||||||
|
|
||||||
# if args.model_type in ['xlnet', 'xlm']:
|
# if args.model_type in ['xlnet', 'xlm']:
|
||||||
# inputs.update({'cls_index': batch[4],
|
# inputs.update({'cls_index': batch[4],
|
||||||
@@ -462,17 +457,16 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
eval_loss += tmp_eval_loss.mean().item()
|
eval_loss += tmp_eval_loss.mean().item()
|
||||||
|
|
||||||
logits = logits.detach().cpu().numpy()
|
logits = logits.detach().cpu().numpy()
|
||||||
label_ids = inputs['labels'].to('cpu').numpy()
|
label_ids = inputs["labels"].to("cpu").numpy()
|
||||||
tmp_eval_accuracy = accuracy(logits, label_ids)
|
tmp_eval_accuracy = accuracy(logits, label_ids)
|
||||||
eval_accuracy += tmp_eval_accuracy
|
eval_accuracy += tmp_eval_accuracy
|
||||||
|
|
||||||
nb_eval_steps += 1
|
nb_eval_steps += 1
|
||||||
nb_eval_examples += inputs['input_ids'].size(0)
|
nb_eval_examples += inputs["input_ids"].size(0)
|
||||||
|
|
||||||
eval_loss = eval_loss / nb_eval_steps
|
eval_loss = eval_loss / nb_eval_steps
|
||||||
eval_accuracy = eval_accuracy / nb_eval_examples
|
eval_accuracy = eval_accuracy / nb_eval_examples
|
||||||
result = {'eval_loss': eval_loss,
|
result = {"eval_loss": eval_loss, "eval_accuracy": eval_accuracy}
|
||||||
'eval_accuracy': eval_accuracy}
|
|
||||||
|
|
||||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||||
with open(output_eval_file, "w") as writer:
|
with open(output_eval_file, "w") as writer:
|
||||||
@@ -483,92 +477,144 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--train_file", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="SWAG csv for training. E.g., train.csv")
|
"--train_file", default=None, type=str, required=True, help="SWAG csv for training. E.g., train.csv"
|
||||||
parser.add_argument("--predict_file", default=None, type=str, required=True,
|
)
|
||||||
help="SWAG csv for predictions. E.g., val.csv or test.csv")
|
parser.add_argument(
|
||||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
"--predict_file",
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
default=None,
|
||||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
type=str,
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
required=True,
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
help="SWAG csv for predictions. E.g., val.csv or test.csv",
|
||||||
help="The output directory where the model checkpoints and predictions will be written.")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_type",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The output directory where the model checkpoints and predictions will be written.",
|
||||||
|
)
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument(
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
)
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
parser.add_argument(
|
||||||
parser.add_argument("--max_seq_length", default=384, type=int,
|
"--tokenizer_name",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_seq_length",
|
||||||
|
default=384,
|
||||||
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences "
|
help="The maximum total input sequence length after tokenization. Sequences "
|
||||||
"longer than this will be truncated, and sequences shorter than this will be padded.")
|
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
||||||
parser.add_argument("--do_train", action='store_true',
|
)
|
||||||
help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action='store_true',
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
help="Whether to run eval on the dev set.")
|
parser.add_argument(
|
||||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||||
help="Rul evaluation during training at each logging step.")
|
)
|
||||||
parser.add_argument("--do_lower_case", action='store_true',
|
parser.add_argument(
|
||||||
help="Set this flag if you are using an uncased model.")
|
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||||
help="Batch size per GPU/CPU for training.")
|
parser.add_argument(
|
||||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||||
help="Batch size per GPU/CPU for evaluation.")
|
)
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
help="The initial learning rate for Adam.")
|
parser.add_argument(
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
"--gradient_accumulation_steps",
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
type=int,
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
default=1,
|
||||||
help="Weight deay if we apply some.")
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
)
|
||||||
help="Epsilon for Adam optimizer.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
help="Max gradient norm.")
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument(
|
||||||
help="Total number of training epochs to perform.")
|
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
)
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
parser.add_argument(
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
"--max_steps",
|
||||||
help="Linear warmup over warmup_steps.")
|
default=-1,
|
||||||
|
type=int,
|
||||||
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
|
||||||
parser.add_argument('--logging_steps', type=int, default=50,
|
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||||
help="Log every X updates steps.")
|
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||||
parser.add_argument('--save_steps', type=int, default=50,
|
parser.add_argument(
|
||||||
help="Save checkpoint every X updates steps.")
|
"--eval_all_checkpoints",
|
||||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
action="store_true",
|
||||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
)
|
||||||
help="Whether not to use CUDA when available")
|
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
parser.add_argument(
|
||||||
help="Overwrite the content of the output directory")
|
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||||
parser.add_argument('--overwrite_cache', action='store_true',
|
)
|
||||||
help="Overwrite the cached training and evaluation sets")
|
parser.add_argument(
|
||||||
parser.add_argument('--seed', type=int, default=42,
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
help="random seed for initialization")
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
|
||||||
parser.add_argument("--local_rank", type=int, default=-1,
|
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||||
help="local_rank for distributed training on gpus")
|
parser.add_argument(
|
||||||
parser.add_argument('--fp16', action='store_true',
|
"--fp16",
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
action="store_true",
|
||||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16_opt_level",
|
||||||
|
type=str,
|
||||||
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html")
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
)
|
||||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||||
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
if (
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
os.path.exists(args.output_dir)
|
||||||
|
and os.listdir(args.output_dir)
|
||||||
|
and args.do_train
|
||||||
|
and not args.overwrite_output_dir
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||||
|
args.output_dir
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Setup distant debugging if needed
|
# Setup distant debugging if needed
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
@@ -580,16 +626,24 @@ def main():
|
|||||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||||
torch.cuda.set_device(args.local_rank)
|
torch.cuda.set_device(args.local_rank)
|
||||||
device = torch.device("cuda", args.local_rank)
|
device = torch.device("cuda", args.local_rank)
|
||||||
torch.distributed.init_process_group(backend='nccl')
|
torch.distributed.init_process_group(backend="nccl")
|
||||||
args.n_gpu = 1
|
args.n_gpu = 1
|
||||||
args.device = device
|
args.device = device
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||||
|
args.local_rank,
|
||||||
|
device,
|
||||||
|
args.n_gpu,
|
||||||
|
bool(args.local_rank != -1),
|
||||||
|
args.fp16,
|
||||||
|
)
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
@@ -601,8 +655,12 @@ def main():
|
|||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case
|
||||||
|
)
|
||||||
|
model = model_class.from_pretrained(
|
||||||
|
args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config
|
||||||
|
)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
@@ -617,7 +675,6 @@ def main():
|
|||||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
|
|
||||||
# Save the trained model and the tokenizer
|
# Save the trained model and the tokenizer
|
||||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
@@ -627,19 +684,20 @@ def main():
|
|||||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(args.output_dir)
|
model_to_save.save_pretrained(args.output_dir)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = model_class.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
|
|
||||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
@@ -650,14 +708,16 @@ def main():
|
|||||||
checkpoints = [args.model_name_or_path]
|
checkpoints = [args.model_name_or_path]
|
||||||
|
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
checkpoints = list(
|
||||||
|
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||||
|
)
|
||||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||||
|
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
|
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
# Reload the model
|
# Reload the model
|
||||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = model_class.from_pretrained(checkpoint)
|
||||||
tokenizer = tokenizer_class.from_pretrained(checkpoint)
|
tokenizer = tokenizer_class.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
@@ -665,7 +725,7 @@ def main():
|
|||||||
# Evaluate
|
# Evaluate
|
||||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||||
|
|
||||||
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
|
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
|
||||||
logger.info("Results: {}".format(results))
|
logger.info("Results: {}".format(results))
|
||||||
|
|||||||
@@ -30,44 +30,36 @@ import torch
|
|||||||
|
|
||||||
from transformers import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer
|
from transformers import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||||
level = logging.INFO)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
|
parser = argparse.ArgumentParser(description="PyTorch Transformer Language Model")
|
||||||
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
|
parser.add_argument("--model_name", type=str, default="transfo-xl-wt103", help="pretrained model name")
|
||||||
help='pretrained model name')
|
parser.add_argument(
|
||||||
parser.add_argument('--split', type=str, default='test',
|
"--split", type=str, default="test", choices=["all", "valid", "test"], help="which split to evaluate"
|
||||||
choices=['all', 'valid', 'test'],
|
)
|
||||||
help='which split to evaluate')
|
parser.add_argument("--batch_size", type=int, default=10, help="batch size")
|
||||||
parser.add_argument('--batch_size', type=int, default=10,
|
parser.add_argument("--tgt_len", type=int, default=128, help="number of tokens to predict")
|
||||||
help='batch size')
|
parser.add_argument("--ext_len", type=int, default=0, help="length of the extended context")
|
||||||
parser.add_argument('--tgt_len', type=int, default=128,
|
parser.add_argument("--mem_len", type=int, default=1600, help="length of the retained previous heads")
|
||||||
help='number of tokens to predict')
|
parser.add_argument("--clamp_len", type=int, default=1000, help="max positional embedding index")
|
||||||
parser.add_argument('--ext_len', type=int, default=0,
|
parser.add_argument("--no_cuda", action="store_true", help="Do not use CUDA even though CUA is available")
|
||||||
help='length of the extended context')
|
parser.add_argument("--work_dir", type=str, required=True, help="path to the work_dir")
|
||||||
parser.add_argument('--mem_len', type=int, default=1600,
|
parser.add_argument("--no_log", action="store_true", help="do not log the eval result")
|
||||||
help='length of the retained previous heads')
|
parser.add_argument("--same_length", action="store_true", help="set same length attention with masking")
|
||||||
parser.add_argument('--clamp_len', type=int, default=1000,
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||||
help='max positional embedding index')
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||||
parser.add_argument('--no_cuda', action='store_true',
|
|
||||||
help='Do not use CUDA even though CUA is available')
|
|
||||||
parser.add_argument('--work_dir', type=str, required=True,
|
|
||||||
help='path to the work_dir')
|
|
||||||
parser.add_argument('--no_log', action='store_true',
|
|
||||||
help='do not log the eval result')
|
|
||||||
parser.add_argument('--same_length', action='store_true',
|
|
||||||
help='set same length attention with masking')
|
|
||||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
|
||||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
assert args.ext_len >= 0, 'extended context length must be non-negative'
|
assert args.ext_len >= 0, "extended context length must be non-negative"
|
||||||
|
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
@@ -84,17 +76,18 @@ def main():
|
|||||||
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
|
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
|
||||||
ntokens = len(corpus.vocab)
|
ntokens = len(corpus.vocab)
|
||||||
|
|
||||||
va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
|
va_iter = corpus.get_iterator("valid", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
|
||||||
device=device, ext_len=args.ext_len)
|
te_iter = corpus.get_iterator("test", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
|
||||||
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
|
|
||||||
device=device, ext_len=args.ext_len)
|
|
||||||
|
|
||||||
# Load a pre-trained model
|
# Load a pre-trained model
|
||||||
model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
|
model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
|
logger.info(
|
||||||
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
|
"Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".format(
|
||||||
|
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
|
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
|
||||||
if args.clamp_len > 0:
|
if args.clamp_len > 0:
|
||||||
@@ -108,7 +101,7 @@ def main():
|
|||||||
def evaluate(eval_iter):
|
def evaluate(eval_iter):
|
||||||
# Turn on evaluation mode which disables dropout.
|
# Turn on evaluation mode which disables dropout.
|
||||||
model.eval()
|
model.eval()
|
||||||
total_len, total_loss = 0, 0.
|
total_len, total_loss = 0, 0.0
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
mems = None
|
mems = None
|
||||||
@@ -119,35 +112,34 @@ def main():
|
|||||||
total_loss += seq_len * loss.item()
|
total_loss += seq_len * loss.item()
|
||||||
total_len += seq_len
|
total_len += seq_len
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
|
logger.info("Time : {:.2f}s, {:.2f}ms/segment".format(total_time, 1000 * total_time / (idx + 1)))
|
||||||
total_time, 1000 * total_time / (idx+1)))
|
|
||||||
return total_loss / total_len
|
return total_loss / total_len
|
||||||
|
|
||||||
# Run on test data.
|
# Run on test data.
|
||||||
if args.split == 'all':
|
if args.split == "all":
|
||||||
test_loss = evaluate(te_iter)
|
test_loss = evaluate(te_iter)
|
||||||
valid_loss = evaluate(va_iter)
|
valid_loss = evaluate(va_iter)
|
||||||
elif args.split == 'valid':
|
elif args.split == "valid":
|
||||||
valid_loss = evaluate(va_iter)
|
valid_loss = evaluate(va_iter)
|
||||||
test_loss = None
|
test_loss = None
|
||||||
elif args.split == 'test':
|
elif args.split == "test":
|
||||||
test_loss = evaluate(te_iter)
|
test_loss = evaluate(te_iter)
|
||||||
valid_loss = None
|
valid_loss = None
|
||||||
|
|
||||||
def format_log(loss, split):
|
def format_log(loss, split):
|
||||||
log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
|
log_str = "| {0} loss {1:5.2f} | {0} ppl {2:9.3f} ".format(split, loss, math.exp(loss))
|
||||||
split, loss, math.exp(loss))
|
|
||||||
return log_str
|
return log_str
|
||||||
|
|
||||||
log_str = ''
|
log_str = ""
|
||||||
if valid_loss is not None:
|
if valid_loss is not None:
|
||||||
log_str += format_log(valid_loss, 'valid')
|
log_str += format_log(valid_loss, "valid")
|
||||||
if test_loss is not None:
|
if test_loss is not None:
|
||||||
log_str += format_log(test_loss, 'test')
|
log_str += format_log(test_loss, "test")
|
||||||
|
|
||||||
logger.info('=' * 100)
|
logger.info("=" * 100)
|
||||||
logger.info(log_str)
|
logger.info(log_str)
|
||||||
logger.info('=' * 100)
|
logger.info("=" * 100)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -40,14 +40,12 @@ from utils import logger
|
|||||||
from lm_seqs_dataset import LmSeqsDataset
|
from lm_seqs_dataset import LmSeqsDataset
|
||||||
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
|
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
|
||||||
|
|
||||||
|
|
||||||
class Distiller:
|
class Distiller:
|
||||||
def __init__(self,
|
def __init__(
|
||||||
params: dict,
|
self, params: dict, dataset: LmSeqsDataset, token_probs: torch.tensor, student: nn.Module, teacher: nn.Module
|
||||||
dataset: LmSeqsDataset,
|
):
|
||||||
token_probs: torch.tensor,
|
logger.info("Initializing Distiller")
|
||||||
student: nn.Module,
|
|
||||||
teacher: nn.Module):
|
|
||||||
logger.info('Initializing Distiller')
|
|
||||||
self.params = params
|
self.params = params
|
||||||
self.dump_path = params.dump_path
|
self.dump_path = params.dump_path
|
||||||
self.multi_gpu = params.multi_gpu
|
self.multi_gpu = params.multi_gpu
|
||||||
@@ -70,12 +68,10 @@ class Distiller:
|
|||||||
else:
|
else:
|
||||||
sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
|
sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
|
||||||
|
|
||||||
self.dataloader = DataLoader(dataset=dataset,
|
self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)
|
||||||
batch_sampler=sampler,
|
|
||||||
collate_fn=dataset.batch_sequences)
|
|
||||||
|
|
||||||
self.temperature = params.temperature
|
self.temperature = params.temperature
|
||||||
assert self.temperature > 0.
|
assert self.temperature > 0.0
|
||||||
|
|
||||||
self.alpha_ce = params.alpha_ce
|
self.alpha_ce = params.alpha_ce
|
||||||
self.alpha_mlm = params.alpha_mlm
|
self.alpha_mlm = params.alpha_mlm
|
||||||
@@ -85,18 +81,18 @@ class Distiller:
|
|||||||
|
|
||||||
self.mlm = params.mlm
|
self.mlm = params.mlm
|
||||||
if self.mlm:
|
if self.mlm:
|
||||||
logger.info(f'Using MLM loss for LM step.')
|
logger.info(f"Using MLM loss for LM step.")
|
||||||
self.mlm_mask_prop = params.mlm_mask_prop
|
self.mlm_mask_prop = params.mlm_mask_prop
|
||||||
assert 0.0 <= self.mlm_mask_prop <= 1.0
|
assert 0.0 <= self.mlm_mask_prop <= 1.0
|
||||||
assert params.word_mask + params.word_keep + params.word_rand == 1.0
|
assert params.word_mask + params.word_keep + params.word_rand == 1.0
|
||||||
self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
|
self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
|
||||||
self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
|
self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
|
||||||
self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
|
self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
|
||||||
if self.fp16:
|
if self.fp16:
|
||||||
self.pred_probs = self.pred_probs.half()
|
self.pred_probs = self.pred_probs.half()
|
||||||
self.token_probs = self.token_probs.half()
|
self.token_probs = self.token_probs.half()
|
||||||
else:
|
else:
|
||||||
logger.info(f'Using CLM loss for LM step.')
|
logger.info(f"Using CLM loss for LM step.")
|
||||||
|
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
self.n_iter = 0
|
self.n_iter = 0
|
||||||
@@ -107,38 +103,54 @@ class Distiller:
|
|||||||
self.last_loss_ce = 0
|
self.last_loss_ce = 0
|
||||||
self.last_loss_mlm = 0
|
self.last_loss_mlm = 0
|
||||||
self.last_loss_clm = 0
|
self.last_loss_clm = 0
|
||||||
if self.alpha_mse > 0.: self.last_loss_mse = 0
|
if self.alpha_mse > 0.0:
|
||||||
if self.alpha_cos > 0.: self.last_loss_cos = 0
|
self.last_loss_mse = 0
|
||||||
|
if self.alpha_cos > 0.0:
|
||||||
|
self.last_loss_cos = 0
|
||||||
self.last_log = 0
|
self.last_log = 0
|
||||||
|
|
||||||
self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
|
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||||
self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
||||||
if self.alpha_mse > 0.:
|
if self.alpha_mse > 0.0:
|
||||||
self.mse_loss_fct = nn.MSELoss(reduction='sum')
|
self.mse_loss_fct = nn.MSELoss(reduction="sum")
|
||||||
if self.alpha_cos > 0.:
|
if self.alpha_cos > 0.0:
|
||||||
self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')
|
self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")
|
||||||
|
|
||||||
logger.info('--- Initializing model optimizer')
|
logger.info("--- Initializing model optimizer")
|
||||||
assert params.gradient_accumulation_steps >= 1
|
assert params.gradient_accumulation_steps >= 1
|
||||||
self.num_steps_epoch = len(self.dataloader)
|
self.num_steps_epoch = len(self.dataloader)
|
||||||
num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
|
num_train_optimization_steps = (
|
||||||
|
int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
|
||||||
|
)
|
||||||
|
|
||||||
no_decay = ['bias', 'LayerNorm.weight']
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{'params': [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': params.weight_decay},
|
{
|
||||||
{'params': [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': 0.0}
|
"params": [
|
||||||
|
p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
|
||||||
|
],
|
||||||
|
"weight_decay": params.weight_decay,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
|
||||||
|
],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
logger.info("------ Number of trainable parameters (student): %i" % sum([p.numel() for p in self.student.parameters() if p.requires_grad]))
|
logger.info(
|
||||||
|
"------ Number of trainable parameters (student): %i"
|
||||||
|
% sum([p.numel() for p in self.student.parameters() if p.requires_grad])
|
||||||
|
)
|
||||||
logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
|
logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
|
||||||
self.optimizer = AdamW(optimizer_grouped_parameters,
|
self.optimizer = AdamW(
|
||||||
lr=params.learning_rate,
|
optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
|
||||||
eps=params.adam_epsilon,
|
)
|
||||||
betas=(0.9, 0.98))
|
|
||||||
|
|
||||||
warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
|
warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
|
||||||
self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
|
self.scheduler = get_linear_schedule_with_warmup(
|
||||||
num_warmup_steps=warmup_steps,
|
self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
|
||||||
num_training_steps=num_train_optimization_steps)
|
)
|
||||||
|
|
||||||
if self.fp16:
|
if self.fp16:
|
||||||
try:
|
try:
|
||||||
@@ -146,33 +158,36 @@ class Distiller:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||||
logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
|
logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
|
||||||
self.student, self.optimizer = amp.initialize(self.student,
|
self.student, self.optimizer = amp.initialize(
|
||||||
self.optimizer,
|
self.student, self.optimizer, opt_level=self.params.fp16_opt_level
|
||||||
opt_level=self.params.fp16_opt_level)
|
)
|
||||||
self.teacher = self.teacher.half()
|
self.teacher = self.teacher.half()
|
||||||
|
|
||||||
if self.multi_gpu:
|
if self.multi_gpu:
|
||||||
if self.fp16:
|
if self.fp16:
|
||||||
from apex.parallel import DistributedDataParallel
|
from apex.parallel import DistributedDataParallel
|
||||||
|
|
||||||
logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
|
logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
|
||||||
self.student = DistributedDataParallel(self.student)
|
self.student = DistributedDataParallel(self.student)
|
||||||
else:
|
else:
|
||||||
from torch.nn.parallel import DistributedDataParallel
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
|
||||||
logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
|
logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
|
||||||
self.student = DistributedDataParallel(self.student,
|
self.student = DistributedDataParallel(
|
||||||
|
self.student,
|
||||||
device_ids=[params.local_rank],
|
device_ids=[params.local_rank],
|
||||||
output_device=params.local_rank,
|
output_device=params.local_rank,
|
||||||
find_unused_parameters=True)
|
find_unused_parameters=True,
|
||||||
|
)
|
||||||
|
|
||||||
self.is_master = params.is_master
|
self.is_master = params.is_master
|
||||||
if self.is_master:
|
if self.is_master:
|
||||||
logger.info('--- Initializing Tensorboard')
|
logger.info("--- Initializing Tensorboard")
|
||||||
self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
|
self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
|
||||||
self.tensorboard.add_text(tag='config/training', text_string=str(self.params), global_step=0)
|
self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
|
||||||
self.tensorboard.add_text(tag='config/student', text_string=str(self.student_config), global_step=0)
|
self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)
|
||||||
|
|
||||||
def prepare_batch_mlm(self,
|
def prepare_batch_mlm(self, batch):
|
||||||
batch):
|
|
||||||
"""
|
"""
|
||||||
Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the masked label for MLM.
|
Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the masked label for MLM.
|
||||||
|
|
||||||
@@ -192,7 +207,7 @@ class Distiller:
|
|||||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||||
assert token_ids.size(0) == lengths.size(0)
|
assert token_ids.size(0) == lengths.size(0)
|
||||||
|
|
||||||
attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
|
attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
|
||||||
|
|
||||||
bs, max_seq_len = token_ids.size()
|
bs, max_seq_len = token_ids.size()
|
||||||
mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
|
mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
|
||||||
@@ -200,11 +215,13 @@ class Distiller:
|
|||||||
x_prob = self.token_probs[token_ids.flatten()]
|
x_prob = self.token_probs[token_ids.flatten()]
|
||||||
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
|
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
|
||||||
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
|
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
|
||||||
pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool, device=token_ids.device) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
|
pred_mask = torch.zeros(
|
||||||
|
bs * max_seq_len, dtype=torch.bool, device=token_ids.device
|
||||||
|
) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
|
||||||
pred_mask[tgt_ids] = 1
|
pred_mask[tgt_ids] = 1
|
||||||
pred_mask = pred_mask.view(bs, max_seq_len)
|
pred_mask = pred_mask.view(bs, max_seq_len)
|
||||||
|
|
||||||
pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0
|
pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0
|
||||||
|
|
||||||
# mask a number of words == 0 [8] (faster with fp16)
|
# mask a number of words == 0 [8] (faster with fp16)
|
||||||
if self.fp16:
|
if self.fp16:
|
||||||
@@ -219,9 +236,13 @@ class Distiller:
|
|||||||
|
|
||||||
_token_ids_real = token_ids[pred_mask]
|
_token_ids_real = token_ids[pred_mask]
|
||||||
_token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
|
_token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
|
||||||
_token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
|
_token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
|
||||||
probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
|
probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
|
||||||
_token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
|
_token_ids = (
|
||||||
|
_token_ids_mask * (probs == 0).long()
|
||||||
|
+ _token_ids_real * (probs == 1).long()
|
||||||
|
+ _token_ids_rand * (probs == 2).long()
|
||||||
|
)
|
||||||
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
|
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
|
||||||
|
|
||||||
mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
|
mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||||
@@ -231,8 +252,7 @@ class Distiller:
|
|||||||
|
|
||||||
return token_ids, attn_mask, mlm_labels
|
return token_ids, attn_mask, mlm_labels
|
||||||
|
|
||||||
def prepare_batch_clm(self,
|
def prepare_batch_clm(self, batch):
|
||||||
batch):
|
|
||||||
"""
|
"""
|
||||||
Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.
|
Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.
|
||||||
|
|
||||||
@@ -252,7 +272,7 @@ class Distiller:
|
|||||||
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
|
||||||
assert token_ids.size(0) == lengths.size(0)
|
assert token_ids.size(0) == lengths.size(0)
|
||||||
|
|
||||||
attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
|
attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
|
||||||
clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
|
clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
|
||||||
clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
|
clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||||
|
|
||||||
@@ -261,9 +281,7 @@ class Distiller:
|
|||||||
|
|
||||||
return token_ids, attn_mask, clm_labels
|
return token_ids, attn_mask, clm_labels
|
||||||
|
|
||||||
def round_batch(self,
|
def round_batch(self, x: torch.tensor, lengths: torch.tensor):
|
||||||
x: torch.tensor,
|
|
||||||
lengths: torch.tensor):
|
|
||||||
"""
|
"""
|
||||||
For float16 only.
|
For float16 only.
|
||||||
Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
|
Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
|
||||||
@@ -299,9 +317,9 @@ class Distiller:
|
|||||||
pad = 8 - (ml1 % 8)
|
pad = 8 - (ml1 % 8)
|
||||||
ml2 = ml1 + pad
|
ml2 = ml1 + pad
|
||||||
if self.mlm:
|
if self.mlm:
|
||||||
pad_id = self.params.special_tok_ids['pad_token']
|
pad_id = self.params.special_tok_ids["pad_token"]
|
||||||
else:
|
else:
|
||||||
pad_id = self.params.special_tok_ids['unk_token']
|
pad_id = self.params.special_tok_ids["unk_token"]
|
||||||
padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
|
padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
|
||||||
x = torch.cat([x, padding_tensor], 1)
|
x = torch.cat([x, padding_tensor], 1)
|
||||||
assert x.size() == (bs2, ml2)
|
assert x.size() == (bs2, ml2)
|
||||||
@@ -314,20 +332,22 @@ class Distiller:
|
|||||||
"""
|
"""
|
||||||
The real training loop.
|
The real training loop.
|
||||||
"""
|
"""
|
||||||
if self.is_master: logger.info('Starting training')
|
if self.is_master:
|
||||||
|
logger.info("Starting training")
|
||||||
self.last_log = time.time()
|
self.last_log = time.time()
|
||||||
self.student.train()
|
self.student.train()
|
||||||
self.teacher.eval()
|
self.teacher.eval()
|
||||||
|
|
||||||
for _ in range(self.params.n_epoch):
|
for _ in range(self.params.n_epoch):
|
||||||
if self.is_master: logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')
|
if self.is_master:
|
||||||
|
logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
|
||||||
if self.multi_gpu:
|
if self.multi_gpu:
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
|
|
||||||
iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
|
iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
|
||||||
for batch in iter_bar:
|
for batch in iter_bar:
|
||||||
if self.params.n_gpu > 0:
|
if self.params.n_gpu > 0:
|
||||||
batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
|
batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch)
|
||||||
|
|
||||||
if self.mlm:
|
if self.mlm:
|
||||||
token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
|
token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
|
||||||
@@ -336,22 +356,21 @@ class Distiller:
|
|||||||
self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)
|
self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)
|
||||||
|
|
||||||
iter_bar.update()
|
iter_bar.update()
|
||||||
iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
|
iter_bar.set_postfix(
|
||||||
'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}'})
|
{"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"}
|
||||||
|
)
|
||||||
iter_bar.close()
|
iter_bar.close()
|
||||||
|
|
||||||
if self.is_master: logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
|
if self.is_master:
|
||||||
|
logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
|
||||||
self.end_epoch()
|
self.end_epoch()
|
||||||
|
|
||||||
if self.is_master:
|
if self.is_master:
|
||||||
logger.info(f'Save very last checkpoint as `pytorch_model.bin`.')
|
logger.info(f"Save very last checkpoint as `pytorch_model.bin`.")
|
||||||
self.save_checkpoint(checkpoint_name=f'pytorch_model.bin')
|
self.save_checkpoint(checkpoint_name=f"pytorch_model.bin")
|
||||||
logger.info('Training is finished')
|
logger.info("Training is finished")
|
||||||
|
|
||||||
def step(self,
|
def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor):
|
||||||
input_ids: torch.tensor,
|
|
||||||
attention_mask: torch.tensor,
|
|
||||||
lm_labels: torch.tensor):
|
|
||||||
"""
|
"""
|
||||||
One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
|
One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
|
||||||
and possibly a parameter update (depending on the gradient accumulation).
|
and possibly a parameter update (depending on the gradient accumulation).
|
||||||
@@ -363,13 +382,21 @@ class Distiller:
|
|||||||
lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
|
lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
|
||||||
"""
|
"""
|
||||||
if self.mlm:
|
if self.mlm:
|
||||||
s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
|
s_logits, s_hidden_states = self.student(
|
||||||
|
input_ids=input_ids, attention_mask=attention_mask
|
||||||
|
) # (bs, seq_length, voc_size)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
|
t_logits, t_hidden_states = self.teacher(
|
||||||
|
input_ids=input_ids, attention_mask=attention_mask
|
||||||
|
) # (bs, seq_length, voc_size)
|
||||||
else:
|
else:
|
||||||
s_logits, _, s_hidden_states = self.student(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size)
|
s_logits, _, s_hidden_states = self.student(
|
||||||
|
input_ids=input_ids, attention_mask=None
|
||||||
|
) # (bs, seq_length, voc_size)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
t_logits, _, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size)
|
t_logits, _, t_hidden_states = self.teacher(
|
||||||
|
input_ids=input_ids, attention_mask=None
|
||||||
|
) # (bs, seq_length, voc_size)
|
||||||
assert s_logits.size() == t_logits.size()
|
assert s_logits.size() == t_logits.size()
|
||||||
|
|
||||||
# https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
|
# https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
|
||||||
@@ -384,24 +411,30 @@ class Distiller:
|
|||||||
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||||
assert t_logits_slct.size() == s_logits_slct.size()
|
assert t_logits_slct.size() == s_logits_slct.size()
|
||||||
|
|
||||||
loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
|
loss_ce = (
|
||||||
F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
|
self.ce_loss_fct(
|
||||||
|
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
||||||
|
F.softmax(t_logits_slct / self.temperature, dim=-1),
|
||||||
|
)
|
||||||
|
* (self.temperature) ** 2
|
||||||
|
)
|
||||||
loss = self.alpha_ce * loss_ce
|
loss = self.alpha_ce * loss_ce
|
||||||
|
|
||||||
if self.alpha_mlm > 0.:
|
if self.alpha_mlm > 0.0:
|
||||||
loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
|
loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
|
||||||
loss += self.alpha_mlm * loss_mlm
|
loss += self.alpha_mlm * loss_mlm
|
||||||
if self.alpha_clm > 0.:
|
if self.alpha_clm > 0.0:
|
||||||
shift_logits = s_logits[..., :-1, :].contiguous()
|
shift_logits = s_logits[..., :-1, :].contiguous()
|
||||||
shift_labels = lm_labels[..., 1:].contiguous()
|
shift_labels = lm_labels[..., 1:].contiguous()
|
||||||
loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||||
shift_labels.view(-1))
|
|
||||||
loss += self.alpha_clm * loss_clm
|
loss += self.alpha_clm * loss_clm
|
||||||
|
|
||||||
if self.alpha_mse > 0.:
|
if self.alpha_mse > 0.0:
|
||||||
loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction
|
loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(
|
||||||
|
0
|
||||||
|
) # Reproducing batchmean reduction
|
||||||
loss += self.alpha_mse * loss_mse
|
loss += self.alpha_mse * loss_mse
|
||||||
if self.alpha_cos > 0.:
|
if self.alpha_cos > 0.0:
|
||||||
s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim)
|
s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim)
|
||||||
t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim)
|
t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim)
|
||||||
mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim)
|
mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim)
|
||||||
@@ -420,21 +453,20 @@ class Distiller:
|
|||||||
self.total_loss_epoch += loss.item()
|
self.total_loss_epoch += loss.item()
|
||||||
self.last_loss = loss.item()
|
self.last_loss = loss.item()
|
||||||
self.last_loss_ce = loss_ce.item()
|
self.last_loss_ce = loss_ce.item()
|
||||||
if self.alpha_mlm > 0.:
|
if self.alpha_mlm > 0.0:
|
||||||
self.last_loss_mlm = loss_mlm.item()
|
self.last_loss_mlm = loss_mlm.item()
|
||||||
if self.alpha_clm > 0.:
|
if self.alpha_clm > 0.0:
|
||||||
self.last_loss_clm = loss_clm.item()
|
self.last_loss_clm = loss_clm.item()
|
||||||
if self.alpha_mse > 0.:
|
if self.alpha_mse > 0.0:
|
||||||
self.last_loss_mse = loss_mse.item()
|
self.last_loss_mse = loss_mse.item()
|
||||||
if self.alpha_cos > 0.:
|
if self.alpha_cos > 0.0:
|
||||||
self.last_loss_cos = loss_cos.item()
|
self.last_loss_cos = loss_cos.item()
|
||||||
|
|
||||||
self.optimize(loss)
|
self.optimize(loss)
|
||||||
|
|
||||||
self.n_sequences_epoch += input_ids.size(0)
|
self.n_sequences_epoch += input_ids.size(0)
|
||||||
|
|
||||||
def optimize(self,
|
def optimize(self, loss):
|
||||||
loss):
|
|
||||||
"""
|
"""
|
||||||
Normalization on the loss (gradient accumulation or distributed training), followed by
|
Normalization on the loss (gradient accumulation or distributed training), followed by
|
||||||
backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
|
backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
|
||||||
@@ -442,7 +474,7 @@ class Distiller:
|
|||||||
"""
|
"""
|
||||||
# Check for NaN
|
# Check for NaN
|
||||||
if (loss != loss).data.any():
|
if (loss != loss).data.any():
|
||||||
logger.error('NaN detected')
|
logger.error("NaN detected")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
if self.multi_gpu:
|
if self.multi_gpu:
|
||||||
@@ -452,6 +484,7 @@ class Distiller:
|
|||||||
|
|
||||||
if self.fp16:
|
if self.fp16:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
|
|
||||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
else:
|
else:
|
||||||
@@ -488,53 +521,84 @@ class Distiller:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for param_name, param in self.student.named_parameters():
|
for param_name, param in self.student.named_parameters():
|
||||||
self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter)
|
self.tensorboard.add_scalar(
|
||||||
self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter)
|
tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter
|
||||||
|
)
|
||||||
|
self.tensorboard.add_scalar(
|
||||||
|
tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter
|
||||||
|
)
|
||||||
if param.grad is None:
|
if param.grad is None:
|
||||||
continue
|
continue
|
||||||
self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(),global_step=self.n_total_iter)
|
self.tensorboard.add_scalar(
|
||||||
self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter)
|
tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter
|
||||||
|
)
|
||||||
|
self.tensorboard.add_scalar(
|
||||||
|
tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter
|
||||||
|
)
|
||||||
|
|
||||||
self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.n_total_iter)
|
self.tensorboard.add_scalar(
|
||||||
|
tag="losses/cum_avg_loss_epoch",
|
||||||
|
scalar_value=self.total_loss_epoch / self.n_iter,
|
||||||
|
global_step=self.n_total_iter,
|
||||||
|
)
|
||||||
self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
|
self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
|
||||||
self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
|
self.tensorboard.add_scalar(
|
||||||
if self.alpha_mlm > 0.:
|
tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter
|
||||||
self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
|
)
|
||||||
if self.alpha_clm > 0.:
|
if self.alpha_mlm > 0.0:
|
||||||
self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter)
|
self.tensorboard.add_scalar(
|
||||||
if self.alpha_mse > 0.:
|
tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter
|
||||||
self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
|
)
|
||||||
if self.alpha_cos > 0.:
|
if self.alpha_clm > 0.0:
|
||||||
self.tensorboard.add_scalar(tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter)
|
self.tensorboard.add_scalar(
|
||||||
self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)
|
tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter
|
||||||
|
)
|
||||||
|
if self.alpha_mse > 0.0:
|
||||||
|
self.tensorboard.add_scalar(
|
||||||
|
tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter
|
||||||
|
)
|
||||||
|
if self.alpha_cos > 0.0:
|
||||||
|
self.tensorboard.add_scalar(
|
||||||
|
tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter
|
||||||
|
)
|
||||||
|
self.tensorboard.add_scalar(
|
||||||
|
tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter
|
||||||
|
)
|
||||||
|
|
||||||
self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)
|
self.tensorboard.add_scalar(
|
||||||
self.tensorboard.add_scalar(tag="global/speed", scalar_value=time.time()-self.last_log, global_step=self.n_total_iter)
|
tag="global/memory_usage",
|
||||||
|
scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
|
||||||
|
global_step=self.n_total_iter,
|
||||||
|
)
|
||||||
|
self.tensorboard.add_scalar(
|
||||||
|
tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter
|
||||||
|
)
|
||||||
|
|
||||||
def end_epoch(self):
|
def end_epoch(self):
|
||||||
"""
|
"""
|
||||||
Finally arrived at the end of epoch (full pass on dataset).
|
Finally arrived at the end of epoch (full pass on dataset).
|
||||||
Do some tensorboard logging and checkpoint saving.
|
Do some tensorboard logging and checkpoint saving.
|
||||||
"""
|
"""
|
||||||
logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.')
|
logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")
|
||||||
|
|
||||||
if self.is_master:
|
if self.is_master:
|
||||||
self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth')
|
self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
|
||||||
self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.epoch)
|
self.tensorboard.add_scalar(
|
||||||
|
tag="epoch/loss", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch
|
||||||
|
)
|
||||||
|
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
self.n_sequences_epoch = 0
|
self.n_sequences_epoch = 0
|
||||||
self.n_iter = 0
|
self.n_iter = 0
|
||||||
self.total_loss_epoch = 0
|
self.total_loss_epoch = 0
|
||||||
|
|
||||||
def save_checkpoint(self,
|
def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
|
||||||
checkpoint_name: str = 'checkpoint.pth'):
|
|
||||||
"""
|
"""
|
||||||
Save the current state. Only by the master process.
|
Save the current state. Only by the master process.
|
||||||
"""
|
"""
|
||||||
if not self.is_master:
|
if not self.is_master:
|
||||||
return
|
return
|
||||||
mdl_to_save = self.student.module if hasattr(self.student, 'module') else self.student
|
mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
|
||||||
mdl_to_save.config.save_pretrained(self.dump_path)
|
mdl_to_save.config.save_pretrained(self.dump_path)
|
||||||
state_dict = mdl_to_save.state_dict()
|
state_dict = mdl_to_save.state_dict()
|
||||||
torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
|
torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
|
||||||
|
|||||||
@@ -23,12 +23,14 @@ from torch.utils.data.sampler import BatchSampler, Sampler
|
|||||||
|
|
||||||
from utils import logger
|
from utils import logger
|
||||||
|
|
||||||
|
|
||||||
def _quantize(x, bins):
|
def _quantize(x, bins):
|
||||||
bins = copy.deepcopy(bins)
|
bins = copy.deepcopy(bins)
|
||||||
bins = sorted(bins)
|
bins = sorted(bins)
|
||||||
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
|
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
|
||||||
return quantized
|
return quantized
|
||||||
|
|
||||||
|
|
||||||
def create_lengths_groups(lengths, k=0):
|
def create_lengths_groups(lengths, k=0):
|
||||||
bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10]
|
bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10]
|
||||||
groups = _quantize(lengths, bins)
|
groups = _quantize(lengths, bins)
|
||||||
@@ -39,6 +41,7 @@ def create_lengths_groups(lengths, k=0):
|
|||||||
logger.info("Count of instances per bin: {}".format(counts))
|
logger.info("Count of instances per bin: {}".format(counts))
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
|
|
||||||
class GroupedBatchSampler(BatchSampler):
|
class GroupedBatchSampler(BatchSampler):
|
||||||
"""
|
"""
|
||||||
Wraps another sampler to yield a mini-batch of indices.
|
Wraps another sampler to yield a mini-batch of indices.
|
||||||
@@ -53,11 +56,11 @@ class GroupedBatchSampler(BatchSampler):
|
|||||||
0, i.e. they must be in the range [0, num_groups).
|
0, i.e. they must be in the range [0, num_groups).
|
||||||
batch_size (int): Size of mini-batch.
|
batch_size (int): Size of mini-batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sampler, group_ids, batch_size):
|
def __init__(self, sampler, group_ids, batch_size):
|
||||||
if not isinstance(sampler, Sampler):
|
if not isinstance(sampler, Sampler):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"sampler should be an instance of "
|
"sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
||||||
"torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
|
||||||
)
|
)
|
||||||
self.sampler = sampler
|
self.sampler = sampler
|
||||||
self.group_ids = group_ids
|
self.group_ids = group_ids
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from torch.utils.data import Dataset
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from utils import logger
|
from utils import logger
|
||||||
|
|
||||||
|
|
||||||
class LmSeqsDataset(Dataset):
|
class LmSeqsDataset(Dataset):
|
||||||
"""Custom Dataset wrapping language modeling sequences.
|
"""Custom Dataset wrapping language modeling sequences.
|
||||||
|
|
||||||
@@ -32,9 +33,7 @@ class LmSeqsDataset(Dataset):
|
|||||||
data: `List[np.array[int]]
|
data: `List[np.array[int]]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, params, data):
|
||||||
params,
|
|
||||||
data):
|
|
||||||
self.params = params
|
self.params = params
|
||||||
|
|
||||||
self.token_ids = np.array(data)
|
self.token_ids = np.array(data)
|
||||||
@@ -65,7 +64,7 @@ class LmSeqsDataset(Dataset):
|
|||||||
"""
|
"""
|
||||||
max_len = self.params.max_model_input_size
|
max_len = self.params.max_model_input_size
|
||||||
indices = self.lengths > max_len
|
indices = self.lengths > max_len
|
||||||
logger.info(f'Splitting {sum(indices)} too long sequences.')
|
logger.info(f"Splitting {sum(indices)} too long sequences.")
|
||||||
|
|
||||||
def divide_chunks(l, n):
|
def divide_chunks(l, n):
|
||||||
return [l[i : i + n] for i in range(0, len(l), n)]
|
return [l[i : i + n] for i in range(0, len(l), n)]
|
||||||
@@ -73,9 +72,9 @@ class LmSeqsDataset(Dataset):
|
|||||||
new_tok_ids = []
|
new_tok_ids = []
|
||||||
new_lengths = []
|
new_lengths = []
|
||||||
if self.params.mlm:
|
if self.params.mlm:
|
||||||
cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token']
|
cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
|
||||||
else:
|
else:
|
||||||
cls_id, sep_id = self.params.special_tok_ids['bos_token'], self.params.special_tok_ids['eos_token']
|
cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]
|
||||||
|
|
||||||
for seq_, len_ in zip(self.token_ids, self.lengths):
|
for seq_, len_ in zip(self.token_ids, self.lengths):
|
||||||
assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
|
assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
|
||||||
@@ -108,7 +107,7 @@ class LmSeqsDataset(Dataset):
|
|||||||
self.token_ids = self.token_ids[indices]
|
self.token_ids = self.token_ids[indices]
|
||||||
self.lengths = self.lengths[indices]
|
self.lengths = self.lengths[indices]
|
||||||
new_size = len(self)
|
new_size = len(self)
|
||||||
logger.info(f'Remove {init_size - new_size} too short (<=11 tokens) sequences.')
|
logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")
|
||||||
|
|
||||||
def print_statistics(self):
|
def print_statistics(self):
|
||||||
"""
|
"""
|
||||||
@@ -116,7 +115,7 @@ class LmSeqsDataset(Dataset):
|
|||||||
"""
|
"""
|
||||||
if not self.params.is_master:
|
if not self.params.is_master:
|
||||||
return
|
return
|
||||||
logger.info(f'{len(self)} sequences')
|
logger.info(f"{len(self)} sequences")
|
||||||
# data_len = sum(self.lengths)
|
# data_len = sum(self.lengths)
|
||||||
# nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
|
# nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
|
||||||
# logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
|
# logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
|
||||||
@@ -125,8 +124,7 @@ class LmSeqsDataset(Dataset):
|
|||||||
# nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids])
|
# nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids])
|
||||||
# logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)')
|
# logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)')
|
||||||
|
|
||||||
def batch_sequences(self,
|
def batch_sequences(self, batch):
|
||||||
batch):
|
|
||||||
"""
|
"""
|
||||||
Do the padding and transform into torch.tensor.
|
Do the padding and transform into torch.tensor.
|
||||||
"""
|
"""
|
||||||
@@ -139,9 +137,9 @@ class LmSeqsDataset(Dataset):
|
|||||||
|
|
||||||
# Pad token ids
|
# Pad token ids
|
||||||
if self.params.mlm:
|
if self.params.mlm:
|
||||||
pad_idx = self.params.special_tok_ids['pad_token']
|
pad_idx = self.params.special_tok_ids["pad_token"]
|
||||||
else:
|
else:
|
||||||
pad_idx = self.params.special_tok_ids['unk_token']
|
pad_idx = self.params.special_tok_ids["unk_token"]
|
||||||
tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
|
tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
|
||||||
assert len(tk_) == len(token_ids)
|
assert len(tk_) == len(token_ids)
|
||||||
assert all(len(t) == max_seq_len_ for t in tk_)
|
assert all(len(t) == max_seq_len_ for t in tk_)
|
||||||
|
|||||||
@@ -25,8 +25,7 @@ import glob
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
TensorDataset)
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -38,19 +37,32 @@ except:
|
|||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
from transformers import (
|
||||||
BertForQuestionAnswering, BertTokenizer,
|
WEIGHTS_NAME,
|
||||||
XLMConfig, XLMForQuestionAnswering,
|
BertConfig,
|
||||||
XLMTokenizer, XLNetConfig,
|
BertForQuestionAnswering,
|
||||||
|
BertTokenizer,
|
||||||
|
XLMConfig,
|
||||||
|
XLMForQuestionAnswering,
|
||||||
|
XLMTokenizer,
|
||||||
|
XLNetConfig,
|
||||||
XLNetForQuestionAnswering,
|
XLNetForQuestionAnswering,
|
||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
DistilBertConfig,
|
||||||
|
DistilBertForQuestionAnswering,
|
||||||
|
DistilBertTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||||
|
|
||||||
from ..utils_squad import (read_squad_examples, convert_examples_to_features,
|
from ..utils_squad import (
|
||||||
RawResult, write_predictions,
|
read_squad_examples,
|
||||||
RawResultExtended, write_predictions_extended)
|
convert_examples_to_features,
|
||||||
|
RawResult,
|
||||||
|
write_predictions,
|
||||||
|
RawResultExtended,
|
||||||
|
write_predictions_extended,
|
||||||
|
)
|
||||||
|
|
||||||
# The follwing import is the official SQuAD evaluation script (2.0).
|
# The follwing import is the official SQuAD evaluation script (2.0).
|
||||||
# You can remove it from the dependencies if you are using this script outside of the library
|
# You can remove it from the dependencies if you are using this script outside of the library
|
||||||
@@ -59,16 +71,18 @@ from ..utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
ALL_MODELS = sum(
|
||||||
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
|
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||||
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||||
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||||
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def set_seed(args):
|
def set_seed(args):
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
@@ -76,9 +90,11 @@ def set_seed(args):
|
|||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
|
||||||
def to_list(tensor):
|
def to_list(tensor):
|
||||||
return tensor.detach().cpu().tolist()
|
return tensor.detach().cpu().tolist()
|
||||||
|
|
||||||
|
|
||||||
def train(args, train_dataset, model, tokenizer, teacher=None):
|
def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||||
""" Train the model """
|
""" Train the model """
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
@@ -95,13 +111,18 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
|
|
||||||
# Prepare optimizer and schedule (linear warmup and decay)
|
# Prepare optimizer and schedule (linear warmup and decay)
|
||||||
no_decay = ['bias', 'LayerNorm.weight']
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
{
|
||||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
},
|
||||||
|
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||||
]
|
]
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||||
|
)
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
@@ -115,17 +136,21 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
output_device=args.local_rank,
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
find_unused_parameters=True)
|
)
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_dataset))
|
logger.info(" Num examples = %d", len(train_dataset))
|
||||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
logger.info(
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
|
args.train_batch_size
|
||||||
|
* args.gradient_accumulation_steps
|
||||||
|
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||||
|
)
|
||||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
@@ -141,35 +166,42 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
if teacher is not None:
|
if teacher is not None:
|
||||||
teacher.eval()
|
teacher.eval()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {
|
||||||
'attention_mask': batch[1],
|
"input_ids": batch[0],
|
||||||
'start_positions': batch[3],
|
"attention_mask": batch[1],
|
||||||
'end_positions': batch[4]}
|
"start_positions": batch[3],
|
||||||
if args.model_type != 'distilbert':
|
"end_positions": batch[4],
|
||||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
|
}
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type != "distilbert":
|
||||||
inputs.update({'cls_index': batch[5],
|
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
|
||||||
'p_mask': batch[6]})
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
|
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss, start_logits_stu, end_logits_stu = outputs
|
loss, start_logits_stu, end_logits_stu = outputs
|
||||||
|
|
||||||
# Distillation loss
|
# Distillation loss
|
||||||
if teacher is not None:
|
if teacher is not None:
|
||||||
if 'token_type_ids' not in inputs:
|
if "token_type_ids" not in inputs:
|
||||||
inputs['token_type_ids'] = None if args.teacher_type == 'xlm' else batch[2]
|
inputs["token_type_ids"] = None if args.teacher_type == "xlm" else batch[2]
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
start_logits_tea, end_logits_tea = teacher(input_ids=inputs['input_ids'],
|
start_logits_tea, end_logits_tea = teacher(
|
||||||
token_type_ids=inputs['token_type_ids'],
|
input_ids=inputs["input_ids"],
|
||||||
attention_mask=inputs['attention_mask'])
|
token_type_ids=inputs["token_type_ids"],
|
||||||
|
attention_mask=inputs["attention_mask"],
|
||||||
|
)
|
||||||
assert start_logits_tea.size() == start_logits_stu.size()
|
assert start_logits_tea.size() == start_logits_stu.size()
|
||||||
assert end_logits_tea.size() == end_logits_stu.size()
|
assert end_logits_tea.size() == end_logits_stu.size()
|
||||||
|
|
||||||
loss_fct = nn.KLDivLoss(reduction='batchmean')
|
loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||||
loss_start = loss_fct(F.log_softmax(start_logits_stu/args.temperature, dim=-1),
|
loss_start = loss_fct(
|
||||||
F.softmax(start_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
|
F.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||||
loss_end = loss_fct(F.log_softmax(end_logits_stu/args.temperature, dim=-1),
|
F.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||||
F.softmax(end_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
|
) * (args.temperature ** 2)
|
||||||
loss_ce = (loss_start + loss_end)/2.
|
loss_end = loss_fct(
|
||||||
|
F.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||||
|
F.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||||
|
) * (args.temperature ** 2)
|
||||||
|
loss_ce = (loss_start + loss_end) / 2.0
|
||||||
|
|
||||||
loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
|
loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
|
||||||
|
|
||||||
@@ -195,22 +227,26 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
# Log metrics
|
# Log metrics
|
||||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
if (
|
||||||
|
args.local_rank == -1 and args.evaluate_during_training
|
||||||
|
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||||
results = evaluate(args, model, tokenizer)
|
results = evaluate(args, model, tokenizer)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||||
logging_loss = tr_loss
|
logging_loss = tr_loss
|
||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(output_dir)
|
model_to_save.save_pretrained(output_dir)
|
||||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
@@ -246,32 +282,31 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
model.eval()
|
model.eval()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
|
||||||
'attention_mask': batch[1]
|
if args.model_type != "distilbert":
|
||||||
}
|
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2] # XLM don't use segment_ids
|
||||||
if args.model_type != 'distilbert':
|
|
||||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
|
||||||
example_indices = batch[3]
|
example_indices = batch[3]
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
inputs.update({'cls_index': batch[4],
|
inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
|
||||||
'p_mask': batch[5]})
|
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
for i, example_index in enumerate(example_indices):
|
for i, example_index in enumerate(example_indices):
|
||||||
eval_feature = features[example_index.item()]
|
eval_feature = features[example_index.item()]
|
||||||
unique_id = int(eval_feature.unique_id)
|
unique_id = int(eval_feature.unique_id)
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
# XLNet uses a more complex post-processing procedure
|
# XLNet uses a more complex post-processing procedure
|
||||||
result = RawResultExtended(unique_id = unique_id,
|
result = RawResultExtended(
|
||||||
|
unique_id=unique_id,
|
||||||
start_top_log_probs=to_list(outputs[0][i]),
|
start_top_log_probs=to_list(outputs[0][i]),
|
||||||
start_top_index=to_list(outputs[1][i]),
|
start_top_index=to_list(outputs[1][i]),
|
||||||
end_top_log_probs=to_list(outputs[2][i]),
|
end_top_log_probs=to_list(outputs[2][i]),
|
||||||
end_top_index=to_list(outputs[3][i]),
|
end_top_index=to_list(outputs[3][i]),
|
||||||
cls_logits = to_list(outputs[4][i]))
|
cls_logits=to_list(outputs[4][i]),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
result = RawResult(unique_id = unique_id,
|
result = RawResult(
|
||||||
start_logits = to_list(outputs[0][i]),
|
unique_id=unique_id, start_logits=to_list(outputs[0][i]), end_logits=to_list(outputs[1][i])
|
||||||
end_logits = to_list(outputs[1][i]))
|
)
|
||||||
all_results.append(result)
|
all_results.append(result)
|
||||||
|
|
||||||
# Compute predictions
|
# Compute predictions
|
||||||
@@ -282,23 +317,44 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
else:
|
else:
|
||||||
output_null_log_odds_file = None
|
output_null_log_odds_file = None
|
||||||
|
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
# XLNet uses a more complex post-processing procedure
|
# XLNet uses a more complex post-processing procedure
|
||||||
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
write_predictions_extended(
|
||||||
args.max_answer_length, output_prediction_file,
|
examples,
|
||||||
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
features,
|
||||||
model.config.start_n_top, model.config.end_n_top,
|
all_results,
|
||||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
args.n_best_size,
|
||||||
|
args.max_answer_length,
|
||||||
|
output_prediction_file,
|
||||||
|
output_nbest_file,
|
||||||
|
output_null_log_odds_file,
|
||||||
|
args.predict_file,
|
||||||
|
model.config.start_n_top,
|
||||||
|
model.config.end_n_top,
|
||||||
|
args.version_2_with_negative,
|
||||||
|
tokenizer,
|
||||||
|
args.verbose_logging,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
write_predictions(examples, features, all_results, args.n_best_size,
|
write_predictions(
|
||||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
examples,
|
||||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
features,
|
||||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
all_results,
|
||||||
|
args.n_best_size,
|
||||||
|
args.max_answer_length,
|
||||||
|
args.do_lower_case,
|
||||||
|
output_prediction_file,
|
||||||
|
output_nbest_file,
|
||||||
|
output_null_log_odds_file,
|
||||||
|
args.verbose_logging,
|
||||||
|
args.version_2_with_negative,
|
||||||
|
args.null_score_diff_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
# Evaluate with the official SQuAD script
|
# Evaluate with the official SQuAD script
|
||||||
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
|
evaluate_options = EVAL_OPTS(
|
||||||
pred_file=output_prediction_file,
|
data_file=args.predict_file, pred_file=output_prediction_file, na_prob_file=output_null_log_odds_file
|
||||||
na_prob_file=output_null_log_odds_file)
|
)
|
||||||
results = evaluate_on_squad(evaluate_options)
|
results = evaluate_on_squad(evaluate_options)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -309,24 +365,30 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
|
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
input_file = args.predict_file if evaluate else args.train_file
|
input_file = args.predict_file if evaluate else args.train_file
|
||||||
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
|
cached_features_file = os.path.join(
|
||||||
'dev' if evaluate else 'train',
|
os.path.dirname(input_file),
|
||||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
"cached_{}_{}_{}".format(
|
||||||
str(args.max_seq_length)))
|
"dev" if evaluate else "train",
|
||||||
|
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||||
|
str(args.max_seq_length),
|
||||||
|
),
|
||||||
|
)
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
features = torch.load(cached_features_file)
|
features = torch.load(cached_features_file)
|
||||||
else:
|
else:
|
||||||
logger.info("Creating features from dataset file at %s", input_file)
|
logger.info("Creating features from dataset file at %s", input_file)
|
||||||
examples = read_squad_examples(input_file=input_file,
|
examples = read_squad_examples(
|
||||||
is_training=not evaluate,
|
input_file=input_file, is_training=not evaluate, version_2_with_negative=args.version_2_with_negative
|
||||||
version_2_with_negative=args.version_2_with_negative)
|
)
|
||||||
features = convert_examples_to_features(examples=examples,
|
features = convert_examples_to_features(
|
||||||
|
examples=examples,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
doc_stride=args.doc_stride,
|
doc_stride=args.doc_stride,
|
||||||
max_query_length=args.max_query_length,
|
max_query_length=args.max_query_length,
|
||||||
is_training=not evaluate)
|
is_training=not evaluate,
|
||||||
|
)
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
torch.save(features, cached_features_file)
|
torch.save(features, cached_features_file)
|
||||||
@@ -342,14 +404,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||||
if evaluate:
|
if evaluate:
|
||||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
dataset = TensorDataset(
|
||||||
all_example_index, all_cls_index, all_p_mask)
|
all_input_ids, all_input_mask, all_segment_ids, all_example_index, all_cls_index, all_p_mask
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||||
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
||||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
dataset = TensorDataset(
|
||||||
all_start_positions, all_end_positions,
|
all_input_ids,
|
||||||
all_cls_index, all_p_mask)
|
all_input_mask,
|
||||||
|
all_segment_ids,
|
||||||
|
all_start_positions,
|
||||||
|
all_end_positions,
|
||||||
|
all_cls_index,
|
||||||
|
all_p_mask,
|
||||||
|
)
|
||||||
|
|
||||||
if output_examples:
|
if output_examples:
|
||||||
return dataset, examples, features
|
return dataset, examples, features
|
||||||
@@ -360,121 +429,213 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--train_file", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="SQuAD json for training. E.g., train-v1.1.json")
|
"--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
|
||||||
parser.add_argument("--predict_file", default=None, type=str, required=True,
|
)
|
||||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
parser.add_argument(
|
||||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
"--predict_file",
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
default=None,
|
||||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
type=str,
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
required=True,
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json",
|
||||||
help="The output directory where the model checkpoints and predictions will be written.")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_type",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The output directory where the model checkpoints and predictions will be written.",
|
||||||
|
)
|
||||||
|
|
||||||
# Distillation parameters (optional)
|
# Distillation parameters (optional)
|
||||||
parser.add_argument('--teacher_type', default=None, type=str,
|
parser.add_argument(
|
||||||
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.")
|
"--teacher_type",
|
||||||
parser.add_argument('--teacher_name_or_path', default=None, type=str,
|
default=None,
|
||||||
help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.")
|
type=str,
|
||||||
parser.add_argument('--alpha_ce', default=0.5, type=float,
|
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
|
||||||
help="Distillation loss linear weight. Only for distillation.")
|
)
|
||||||
parser.add_argument('--alpha_squad', default=0.5, type=float,
|
parser.add_argument(
|
||||||
help="True SQuAD loss linear weight. Only for distillation.")
|
"--teacher_name_or_path",
|
||||||
parser.add_argument('--temperature', default=2.0, type=float,
|
default=None,
|
||||||
help="Distillation temperature. Only for distillation.")
|
type=str,
|
||||||
|
help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--alpha_ce", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--alpha_squad", default=0.5, type=float, help="True SQuAD loss linear weight. Only for distillation."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
|
||||||
|
)
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument(
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
)
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
parser.add_argument(
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
"--tokenizer_name",
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_dir",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('--version_2_with_negative', action='store_true',
|
parser.add_argument(
|
||||||
help='If true, the SQuAD examples contain some that do not have an answer.')
|
"--version_2_with_negative",
|
||||||
parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
|
action="store_true",
|
||||||
help="If null_score - best_non_null is greater than the threshold predict null.")
|
help="If true, the SQuAD examples contain some that do not have an answer.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--null_score_diff_threshold",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="If null_score - best_non_null is greater than the threshold predict null.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--max_seq_length", default=384, type=int,
|
parser.add_argument(
|
||||||
|
"--max_seq_length",
|
||||||
|
default=384,
|
||||||
|
type=int,
|
||||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||||
"longer than this will be truncated, and sequences shorter than this will be padded.")
|
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
||||||
parser.add_argument("--doc_stride", default=128, type=int,
|
)
|
||||||
help="When splitting up a long document into chunks, how much stride to take between chunks.")
|
parser.add_argument(
|
||||||
parser.add_argument("--max_query_length", default=64, type=int,
|
"--doc_stride",
|
||||||
|
default=128,
|
||||||
|
type=int,
|
||||||
|
help="When splitting up a long document into chunks, how much stride to take between chunks.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_query_length",
|
||||||
|
default=64,
|
||||||
|
type=int,
|
||||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||||
"be truncated to this length.")
|
"be truncated to this length.",
|
||||||
parser.add_argument("--do_train", action='store_true',
|
)
|
||||||
help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action='store_true',
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
help="Whether to run eval on the dev set.")
|
parser.add_argument(
|
||||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||||
help="Rul evaluation during training at each logging step.")
|
)
|
||||||
parser.add_argument("--do_lower_case", action='store_true',
|
parser.add_argument(
|
||||||
help="Set this flag if you are using an uncased model.")
|
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||||
help="Batch size per GPU/CPU for training.")
|
parser.add_argument(
|
||||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||||
help="Batch size per GPU/CPU for evaluation.")
|
)
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
help="The initial learning rate for Adam.")
|
parser.add_argument(
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
"--gradient_accumulation_steps",
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
type=int,
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
default=1,
|
||||||
help="Weight deay if we apply some.")
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
)
|
||||||
help="Epsilon for Adam optimizer.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
help="Max gradient norm.")
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument(
|
||||||
help="Total number of training epochs to perform.")
|
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
)
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
parser.add_argument(
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
"--max_steps",
|
||||||
help="Linear warmup over warmup_steps.")
|
default=-1,
|
||||||
parser.add_argument("--n_best_size", default=20, type=int,
|
type=int,
|
||||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||||
parser.add_argument("--max_answer_length", default=30, type=int,
|
)
|
||||||
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_best_size",
|
||||||
|
default=20,
|
||||||
|
type=int,
|
||||||
|
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_answer_length",
|
||||||
|
default=30,
|
||||||
|
type=int,
|
||||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
"and end predictions are not conditioned on one another.")
|
"and end predictions are not conditioned on one another.",
|
||||||
parser.add_argument("--verbose_logging", action='store_true',
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose_logging",
|
||||||
|
action="store_true",
|
||||||
help="If true, all of the warnings related to data processing will be printed. "
|
help="If true, all of the warnings related to data processing will be printed. "
|
||||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
"A number of warnings are expected for a normal SQuAD evaluation.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('--logging_steps', type=int, default=50,
|
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||||
help="Log every X updates steps.")
|
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||||
parser.add_argument('--save_steps', type=int, default=50,
|
parser.add_argument(
|
||||||
help="Save checkpoint every X updates steps.")
|
"--eval_all_checkpoints",
|
||||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
action="store_true",
|
||||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
)
|
||||||
help="Whether not to use CUDA when available")
|
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
parser.add_argument(
|
||||||
help="Overwrite the content of the output directory")
|
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||||
parser.add_argument('--overwrite_cache', action='store_true',
|
)
|
||||||
help="Overwrite the cached training and evaluation sets")
|
parser.add_argument(
|
||||||
parser.add_argument('--seed', type=int, default=42,
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
help="random seed for initialization")
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
|
||||||
parser.add_argument("--local_rank", type=int, default=-1,
|
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||||
help="local_rank for distributed training on gpus")
|
parser.add_argument(
|
||||||
parser.add_argument('--fp16', action='store_true',
|
"--fp16",
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
action="store_true",
|
||||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16_opt_level",
|
||||||
|
type=str,
|
||||||
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html")
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
)
|
||||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||||
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
if (
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
os.path.exists(args.output_dir)
|
||||||
|
and os.listdir(args.output_dir)
|
||||||
|
and args.do_train
|
||||||
|
and not args.overwrite_output_dir
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||||
|
args.output_dir
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Setup distant debugging if needed
|
# Setup distant debugging if needed
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
@@ -486,16 +647,24 @@ def main():
|
|||||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||||
torch.cuda.set_device(args.local_rank)
|
torch.cuda.set_device(args.local_rank)
|
||||||
device = torch.device("cuda", args.local_rank)
|
device = torch.device("cuda", args.local_rank)
|
||||||
torch.distributed.init_process_group(backend='nccl')
|
torch.distributed.init_process_group(backend="nccl")
|
||||||
args.n_gpu = 1
|
args.n_gpu = 1
|
||||||
args.device = device
|
args.device = device
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||||
|
args.local_rank,
|
||||||
|
device,
|
||||||
|
args.n_gpu,
|
||||||
|
bool(args.local_rank != -1),
|
||||||
|
args.fp16,
|
||||||
|
)
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
@@ -506,27 +675,34 @@ def main():
|
|||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
config = config_class.from_pretrained(
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
do_lower_case=args.do_lower_case,
|
do_lower_case=args.do_lower_case,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
model = model_class.from_pretrained(args.model_name_or_path,
|
)
|
||||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
model = model_class.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
|
||||||
if args.teacher_type is not None:
|
if args.teacher_type is not None:
|
||||||
assert args.teacher_name_or_path is not None
|
assert args.teacher_name_or_path is not None
|
||||||
assert args.alpha_ce > 0.
|
assert args.alpha_ce > 0.0
|
||||||
assert args.alpha_ce + args.alpha_squad > 0.
|
assert args.alpha_ce + args.alpha_squad > 0.0
|
||||||
assert args.teacher_type != 'distilbert', "We constraint teachers not to be of type DistilBERT."
|
assert args.teacher_type != "distilbert", "We constraint teachers not to be of type DistilBERT."
|
||||||
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
|
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
|
||||||
teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path,
|
teacher_config = teacher_config_class.from_pretrained(
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
args.teacher_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None
|
||||||
teacher = teacher_model_class.from_pretrained(args.teacher_name_or_path,
|
)
|
||||||
config=teacher_config,
|
teacher = teacher_model_class.from_pretrained(
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
args.teacher_name_or_path, config=teacher_config, cache_dir=args.cache_dir if args.cache_dir else None
|
||||||
|
)
|
||||||
teacher.to(args.device)
|
teacher.to(args.device)
|
||||||
else:
|
else:
|
||||||
teacher = None
|
teacher = None
|
||||||
@@ -544,7 +720,6 @@ def main():
|
|||||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
|
global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
|
||||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
|
|
||||||
# Save the trained model and the tokenizer
|
# Save the trained model and the tokenizer
|
||||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
@@ -554,41 +729,44 @@ def main():
|
|||||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(args.output_dir)
|
model_to_save.save_pretrained(args.output_dir)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir, cache_dir=args.cache_dir if args.cache_dir else None)
|
model = model_class.from_pretrained(args.output_dir, cache_dir=args.cache_dir if args.cache_dir else None)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir,
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
do_lower_case=args.do_lower_case,
|
args.output_dir, do_lower_case=args.do_lower_case, cache_dir=args.cache_dir if args.cache_dir else None
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
|
|
||||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
checkpoints = list(
|
||||||
|
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||||
|
)
|
||||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||||
|
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
|
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
# Reload the model
|
# Reload the model
|
||||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
model = model_class.from_pretrained(checkpoint, cache_dir=args.cache_dir if args.cache_dir else None)
|
model = model_class.from_pretrained(checkpoint, cache_dir=args.cache_dir if args.cache_dir else None)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||||
|
|
||||||
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
|
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
|
||||||
logger.info("Results: {}".format(results))
|
logger.info("Results: {}".format(results))
|
||||||
|
|||||||
@@ -23,68 +23,65 @@ import numpy as np
|
|||||||
from transformers import BertTokenizer, RobertaTokenizer, GPT2Tokenizer
|
from transformers import BertTokenizer, RobertaTokenizer, GPT2Tokenizer
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||||
level = logging.INFO)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids).")
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--file_path', type=str, default='data/dump.txt',
|
description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids)."
|
||||||
help='The path to the data.')
|
)
|
||||||
parser.add_argument('--tokenizer_type', type=str, default='bert', choices=['bert', 'roberta', 'gpt2'])
|
parser.add_argument("--file_path", type=str, default="data/dump.txt", help="The path to the data.")
|
||||||
parser.add_argument('--tokenizer_name', type=str, default='bert-base-uncased',
|
parser.add_argument("--tokenizer_type", type=str, default="bert", choices=["bert", "roberta", "gpt2"])
|
||||||
help="The tokenizer to use.")
|
parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased", help="The tokenizer to use.")
|
||||||
parser.add_argument('--dump_file', type=str, default='data/dump',
|
parser.add_argument("--dump_file", type=str, default="data/dump", help="The dump file prefix.")
|
||||||
help='The dump file prefix.')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logger.info(f"Loading Tokenizer ({args.tokenizer_name})")
|
||||||
logger.info(f'Loading Tokenizer ({args.tokenizer_name})')
|
if args.tokenizer_type == "bert":
|
||||||
if args.tokenizer_type == 'bert':
|
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
|
tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
|
||||||
bos = tokenizer.special_tokens_map['cls_token'] # `[CLS]`
|
bos = tokenizer.special_tokens_map["cls_token"] # `[CLS]`
|
||||||
sep = tokenizer.special_tokens_map['sep_token'] # `[SEP]`
|
sep = tokenizer.special_tokens_map["sep_token"] # `[SEP]`
|
||||||
elif args.tokenizer_type == 'roberta':
|
elif args.tokenizer_type == "roberta":
|
||||||
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
|
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
|
||||||
bos = tokenizer.special_tokens_map['cls_token'] # `<s>`
|
bos = tokenizer.special_tokens_map["cls_token"] # `<s>`
|
||||||
sep = tokenizer.special_tokens_map['sep_token'] # `</s>`
|
sep = tokenizer.special_tokens_map["sep_token"] # `</s>`
|
||||||
elif args.tokenizer_type == 'gpt2':
|
elif args.tokenizer_type == "gpt2":
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
|
tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
|
||||||
bos = tokenizer.special_tokens_map['bos_token'] # `<|endoftext|>`
|
bos = tokenizer.special_tokens_map["bos_token"] # `<|endoftext|>`
|
||||||
sep = tokenizer.special_tokens_map['eos_token'] # `<|endoftext|>`
|
sep = tokenizer.special_tokens_map["eos_token"] # `<|endoftext|>`
|
||||||
|
|
||||||
logger.info(f'Loading text from {args.file_path}')
|
logger.info(f"Loading text from {args.file_path}")
|
||||||
with open(args.file_path, 'r', encoding='utf8') as fp:
|
with open(args.file_path, "r", encoding="utf8") as fp:
|
||||||
data = fp.readlines()
|
data = fp.readlines()
|
||||||
|
|
||||||
|
logger.info(f"Start encoding")
|
||||||
logger.info(f'Start encoding')
|
logger.info(f"{len(data)} examples to process.")
|
||||||
logger.info(f'{len(data)} examples to process.')
|
|
||||||
|
|
||||||
rslt = []
|
rslt = []
|
||||||
iter = 0
|
iter = 0
|
||||||
interval = 10000
|
interval = 10000
|
||||||
start = time.time()
|
start = time.time()
|
||||||
for text in data:
|
for text in data:
|
||||||
text = f'{bos} {text.strip()} {sep}'
|
text = f"{bos} {text.strip()} {sep}"
|
||||||
token_ids = tokenizer.encode(text, add_special_tokens=False)
|
token_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||||
rslt.append(token_ids)
|
rslt.append(token_ids)
|
||||||
|
|
||||||
iter += 1
|
iter += 1
|
||||||
if iter % interval == 0:
|
if iter % interval == 0:
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.info(f'{iter} examples processed. - {(end-start)/interval:.2f}s/expl')
|
logger.info(f"{iter} examples processed. - {(end-start)/interval:.2f}s/expl")
|
||||||
start = time.time()
|
start = time.time()
|
||||||
logger.info('Finished binarization')
|
logger.info("Finished binarization")
|
||||||
logger.info(f'{len(data)} examples processed.')
|
logger.info(f"{len(data)} examples processed.")
|
||||||
|
|
||||||
|
dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle"
|
||||||
dp_file = f'{args.dump_file}.{args.tokenizer_name}.pickle'
|
|
||||||
rslt_ = [np.uint16(d) for d in rslt]
|
rslt_ = [np.uint16(d) for d in rslt]
|
||||||
random.shuffle(rslt_)
|
random.shuffle(rslt_)
|
||||||
logger.info(f'Dump to {dp_file}')
|
logger.info(f"Dump to {dp_file}")
|
||||||
with open(dp_file, 'wb') as handle:
|
with open(dp_file, "wb") as handle:
|
||||||
pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,70 +20,80 @@ from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel
|
|||||||
import torch
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation")
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
|
||||||
|
)
|
||||||
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
|
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
|
||||||
parser.add_argument("--model_name", default='roberta-large', type=str)
|
parser.add_argument("--model_name", default="roberta-large", type=str)
|
||||||
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_roberta_048131723.pth', type=str)
|
parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
|
||||||
parser.add_argument("--vocab_transform", action='store_true')
|
parser.add_argument("--vocab_transform", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.model_type == "roberta":
|
||||||
if args.model_type == 'roberta':
|
|
||||||
model = RobertaForMaskedLM.from_pretrained(args.model_name)
|
model = RobertaForMaskedLM.from_pretrained(args.model_name)
|
||||||
prefix = 'roberta'
|
prefix = "roberta"
|
||||||
elif args.model_type == 'gpt2':
|
elif args.model_type == "gpt2":
|
||||||
model = GPT2LMHeadModel.from_pretrained(args.model_name)
|
model = GPT2LMHeadModel.from_pretrained(args.model_name)
|
||||||
prefix = 'transformer'
|
prefix = "transformer"
|
||||||
|
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
compressed_sd = {}
|
compressed_sd = {}
|
||||||
|
|
||||||
### Embeddings ###
|
### Embeddings ###
|
||||||
if args.model_type == 'gpt2':
|
if args.model_type == "gpt2":
|
||||||
for param_name in ['wte.weight', 'wpe.weight']:
|
for param_name in ["wte.weight", "wpe.weight"]:
|
||||||
compressed_sd[f'{prefix}.{param_name}'] = state_dict[f'{prefix}.{param_name}']
|
compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
|
||||||
else:
|
else:
|
||||||
for w in ['word_embeddings', 'position_embeddings', 'token_type_embeddings']:
|
for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
|
||||||
param_name = f'{prefix}.embeddings.{w}.weight'
|
param_name = f"{prefix}.embeddings.{w}.weight"
|
||||||
compressed_sd[param_name] = state_dict[param_name]
|
compressed_sd[param_name] = state_dict[param_name]
|
||||||
for w in ['weight', 'bias']:
|
for w in ["weight", "bias"]:
|
||||||
param_name = f'{prefix}.embeddings.LayerNorm.{w}'
|
param_name = f"{prefix}.embeddings.LayerNorm.{w}"
|
||||||
compressed_sd[param_name] = state_dict[param_name]
|
compressed_sd[param_name] = state_dict[param_name]
|
||||||
|
|
||||||
### Transformer Blocks ###
|
### Transformer Blocks ###
|
||||||
std_idx = 0
|
std_idx = 0
|
||||||
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
||||||
if args.model_type == 'gpt2':
|
if args.model_type == "gpt2":
|
||||||
for layer in ['ln_1', 'attn.c_attn', 'attn.c_proj', 'ln_2', 'mlp.c_fc', 'mlp.c_proj']:
|
for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
|
||||||
for w in ['weight', 'bias']:
|
for w in ["weight", "bias"]:
|
||||||
compressed_sd[f'{prefix}.h.{std_idx}.{layer}.{w}'] = \
|
compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
|
||||||
state_dict[f'{prefix}.h.{teacher_idx}.{layer}.{w}']
|
f"{prefix}.h.{teacher_idx}.{layer}.{w}"
|
||||||
compressed_sd[f'{prefix}.h.{std_idx}.attn.bias'] = state_dict[f'{prefix}.h.{teacher_idx}.attn.bias']
|
]
|
||||||
|
compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
|
||||||
else:
|
else:
|
||||||
for layer in ['attention.self.query', 'attention.self.key', 'attention.self.value',
|
for layer in [
|
||||||
'attention.output.dense', 'attention.output.LayerNorm',
|
"attention.self.query",
|
||||||
'intermediate.dense', 'output.dense', 'output.LayerNorm']:
|
"attention.self.key",
|
||||||
for w in ['weight', 'bias']:
|
"attention.self.value",
|
||||||
compressed_sd[f'{prefix}.encoder.layer.{std_idx}.{layer}.{w}'] = \
|
"attention.output.dense",
|
||||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}']
|
"attention.output.LayerNorm",
|
||||||
|
"intermediate.dense",
|
||||||
|
"output.dense",
|
||||||
|
"output.LayerNorm",
|
||||||
|
]:
|
||||||
|
for w in ["weight", "bias"]:
|
||||||
|
compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
|
||||||
|
f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
|
||||||
|
]
|
||||||
std_idx += 1
|
std_idx += 1
|
||||||
|
|
||||||
### Language Modeling Head ###s
|
### Language Modeling Head ###s
|
||||||
if args.model_type == 'roberta':
|
if args.model_type == "roberta":
|
||||||
for layer in ['lm_head.decoder.weight', 'lm_head.bias']:
|
for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
|
||||||
compressed_sd[f'{layer}'] = state_dict[f'{layer}']
|
compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
|
||||||
if args.vocab_transform:
|
if args.vocab_transform:
|
||||||
for w in ['weight', 'bias']:
|
for w in ["weight", "bias"]:
|
||||||
compressed_sd[f'lm_head.dense.{w}'] = state_dict[f'lm_head.dense.{w}']
|
compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
|
||||||
compressed_sd[f'lm_head.layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']
|
compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
|
||||||
elif args.model_type == 'gpt2':
|
elif args.model_type == "gpt2":
|
||||||
for w in ['weight', 'bias']:
|
for w in ["weight", "bias"]:
|
||||||
compressed_sd[f'{prefix}.ln_f.{w}'] = state_dict[f'{prefix}.ln_f.{w}']
|
compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
|
||||||
compressed_sd[f'lm_head.weight'] = state_dict[f'lm_head.weight']
|
compressed_sd[f"lm_head.weight"] = state_dict[f"lm_head.weight"]
|
||||||
|
|
||||||
print(f'N layers selected for distillation: {std_idx}')
|
print(f"N layers selected for distillation: {std_idx}")
|
||||||
print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')
|
print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")
|
||||||
|
|
||||||
print(f'Save transfered checkpoint to {args.dump_checkpoint}.')
|
print(f"Save transfered checkpoint to {args.dump_checkpoint}.")
|
||||||
torch.save(compressed_sd, args.dump_checkpoint)
|
torch.save(compressed_sd, args.dump_checkpoint)
|
||||||
|
|||||||
@@ -20,63 +20,70 @@ from transformers import BertForMaskedLM, RobertaForMaskedLM
|
|||||||
import torch
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation")
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
|
||||||
|
)
|
||||||
parser.add_argument("--model_type", default="bert", choices=["bert"])
|
parser.add_argument("--model_type", default="bert", choices=["bert"])
|
||||||
parser.add_argument("--model_name", default='bert-base-uncased', type=str)
|
parser.add_argument("--model_name", default="bert-base-uncased", type=str)
|
||||||
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str)
|
parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
|
||||||
parser.add_argument("--vocab_transform", action='store_true')
|
parser.add_argument("--vocab_transform", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.model_type == "bert":
|
||||||
if args.model_type == 'bert':
|
|
||||||
model = BertForMaskedLM.from_pretrained(args.model_name)
|
model = BertForMaskedLM.from_pretrained(args.model_name)
|
||||||
prefix = 'bert'
|
prefix = "bert"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'args.model_type should be "bert".')
|
raise ValueError(f'args.model_type should be "bert".')
|
||||||
|
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
compressed_sd = {}
|
compressed_sd = {}
|
||||||
|
|
||||||
for w in ['word_embeddings', 'position_embeddings']:
|
for w in ["word_embeddings", "position_embeddings"]:
|
||||||
compressed_sd[f'distilbert.embeddings.{w}.weight'] = \
|
compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
|
||||||
state_dict[f'{prefix}.embeddings.{w}.weight']
|
for w in ["weight", "bias"]:
|
||||||
for w in ['weight', 'bias']:
|
compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]
|
||||||
compressed_sd[f'distilbert.embeddings.LayerNorm.{w}'] = \
|
|
||||||
state_dict[f'{prefix}.embeddings.LayerNorm.{w}']
|
|
||||||
|
|
||||||
std_idx = 0
|
std_idx = 0
|
||||||
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
||||||
for w in ['weight', 'bias']:
|
for w in ["weight", "bias"]:
|
||||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \
|
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
|
||||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}']
|
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
|
||||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \
|
]
|
||||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}']
|
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
|
||||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \
|
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
|
||||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}']
|
]
|
||||||
|
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
|
||||||
|
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
|
||||||
|
]
|
||||||
|
|
||||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \
|
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
|
||||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}']
|
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
|
||||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \
|
]
|
||||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}']
|
compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
|
||||||
|
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
|
||||||
|
]
|
||||||
|
|
||||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \
|
compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
|
||||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}']
|
f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
|
||||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \
|
]
|
||||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}']
|
compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
|
||||||
compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
|
f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
|
||||||
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
|
]
|
||||||
|
compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
|
||||||
|
f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
|
||||||
|
]
|
||||||
std_idx += 1
|
std_idx += 1
|
||||||
|
|
||||||
compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
|
compressed_sd[f"vocab_projector.weight"] = state_dict[f"cls.predictions.decoder.weight"]
|
||||||
compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
|
compressed_sd[f"vocab_projector.bias"] = state_dict[f"cls.predictions.bias"]
|
||||||
if args.vocab_transform:
|
if args.vocab_transform:
|
||||||
for w in ['weight', 'bias']:
|
for w in ["weight", "bias"]:
|
||||||
compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
|
compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
|
||||||
compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
|
compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
|
||||||
|
|
||||||
print(f'N layers selected for distillation: {std_idx}')
|
print(f"N layers selected for distillation: {std_idx}")
|
||||||
print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')
|
print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")
|
||||||
|
|
||||||
print(f'Save transfered checkpoint to {args.dump_checkpoint}.')
|
print(f"Save transfered checkpoint to {args.dump_checkpoint}.")
|
||||||
torch.save(compressed_sd, args.dump_checkpoint)
|
torch.save(compressed_sd, args.dump_checkpoint)
|
||||||
|
|||||||
@@ -20,25 +20,29 @@ import argparse
|
|||||||
import pickle
|
import pickle
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||||
level = logging.INFO)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)")
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument("--data_file", type=str, default="data/dump.bert-base-uncased.pickle",
|
description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)"
|
||||||
help="The binarized dataset.")
|
)
|
||||||
parser.add_argument("--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle",
|
parser.add_argument(
|
||||||
help="The dump file.")
|
"--data_file", type=str, default="data/dump.bert-base-uncased.pickle", help="The binarized dataset."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", help="The dump file."
|
||||||
|
)
|
||||||
parser.add_argument("--vocab_size", default=30522, type=int)
|
parser.add_argument("--vocab_size", default=30522, type=int)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
logger.info(f'Loading data from {args.data_file}')
|
logger.info(f"Loading data from {args.data_file}")
|
||||||
with open(args.data_file, 'rb') as fp:
|
with open(args.data_file, "rb") as fp:
|
||||||
data = pickle.load(fp)
|
data = pickle.load(fp)
|
||||||
|
|
||||||
logger.info('Counting occurences for MLM.')
|
logger.info("Counting occurences for MLM.")
|
||||||
counter = Counter()
|
counter = Counter()
|
||||||
for tk_ids in data:
|
for tk_ids in data:
|
||||||
counter.update(tk_ids)
|
counter.update(tk_ids)
|
||||||
@@ -46,6 +50,6 @@ if __name__ == '__main__':
|
|||||||
for k, v in counter.items():
|
for k, v in counter.items():
|
||||||
counts[k] = v
|
counts[k] = v
|
||||||
|
|
||||||
logger.info(f'Dump to {args.token_counts_dump}')
|
logger.info(f"Dump to {args.token_counts_dump}")
|
||||||
with open(args.token_counts_dump, 'wb') as handle:
|
with open(args.token_counts_dump, "wb") as handle:
|
||||||
pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|||||||
@@ -35,166 +35,200 @@ from lm_seqs_dataset import LmSeqsDataset
|
|||||||
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||||
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||||
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
|
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||||
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
|
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def sanity_checks(args):
|
def sanity_checks(args):
|
||||||
"""
|
"""
|
||||||
A bunch of args sanity checks to perform even starting...
|
A bunch of args sanity checks to perform even starting...
|
||||||
"""
|
"""
|
||||||
assert (args.mlm and args.alpha_mlm > 0.) or (not args.mlm and args.alpha_mlm == 0.)
|
assert (args.mlm and args.alpha_mlm > 0.0) or (not args.mlm and args.alpha_mlm == 0.0)
|
||||||
assert (args.alpha_mlm > 0. and args.alpha_clm == 0.) or (args.alpha_mlm == 0. and args.alpha_clm > 0.)
|
assert (args.alpha_mlm > 0.0 and args.alpha_clm == 0.0) or (args.alpha_mlm == 0.0 and args.alpha_clm > 0.0)
|
||||||
if args.mlm:
|
if args.mlm:
|
||||||
assert os.path.isfile(args.token_counts)
|
assert os.path.isfile(args.token_counts)
|
||||||
assert (args.student_type in ['roberta', 'distilbert']) and (args.teacher_type in ['roberta', 'bert'])
|
assert (args.student_type in ["roberta", "distilbert"]) and (args.teacher_type in ["roberta", "bert"])
|
||||||
else:
|
else:
|
||||||
assert (args.student_type in ['gpt2']) and (args.teacher_type in ['gpt2'])
|
assert (args.student_type in ["gpt2"]) and (args.teacher_type in ["gpt2"])
|
||||||
|
|
||||||
assert args.teacher_type == args.student_type or (args.student_type=='distilbert' and args.teacher_type=='bert')
|
assert args.teacher_type == args.student_type or (
|
||||||
|
args.student_type == "distilbert" and args.teacher_type == "bert"
|
||||||
|
)
|
||||||
assert os.path.isfile(args.student_config)
|
assert os.path.isfile(args.student_config)
|
||||||
if args.student_pretrained_weights is not None:
|
if args.student_pretrained_weights is not None:
|
||||||
assert os.path.isfile(args.student_pretrained_weights)
|
assert os.path.isfile(args.student_pretrained_weights)
|
||||||
|
|
||||||
if args.freeze_token_type_embds: assert args.student_type in ['roberta']
|
if args.freeze_token_type_embds:
|
||||||
|
assert args.student_type in ["roberta"]
|
||||||
|
|
||||||
|
assert args.alpha_ce >= 0.0
|
||||||
|
assert args.alpha_mlm >= 0.0
|
||||||
|
assert args.alpha_clm >= 0.0
|
||||||
|
assert args.alpha_mse >= 0.0
|
||||||
|
assert args.alpha_cos >= 0.0
|
||||||
|
assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.0
|
||||||
|
|
||||||
assert args.alpha_ce >= 0.
|
|
||||||
assert args.alpha_mlm >= 0.
|
|
||||||
assert args.alpha_clm >= 0.
|
|
||||||
assert args.alpha_mse >= 0.
|
|
||||||
assert args.alpha_cos >= 0.
|
|
||||||
assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.
|
|
||||||
|
|
||||||
def freeze_pos_embeddings(student, args):
|
def freeze_pos_embeddings(student, args):
|
||||||
if args.student_type == 'roberta':
|
if args.student_type == "roberta":
|
||||||
student.roberta.embeddings.position_embeddings.weight.requires_grad = False
|
student.roberta.embeddings.position_embeddings.weight.requires_grad = False
|
||||||
elif args.student_type == 'gpt2':
|
elif args.student_type == "gpt2":
|
||||||
student.transformer.wpe.weight.requires_grad = False
|
student.transformer.wpe.weight.requires_grad = False
|
||||||
|
|
||||||
|
|
||||||
def freeze_token_type_embeddings(student, args):
|
def freeze_token_type_embeddings(student, args):
|
||||||
if args.student_type == 'roberta':
|
if args.student_type == "roberta":
|
||||||
student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False
|
student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Training")
|
parser = argparse.ArgumentParser(description="Training")
|
||||||
parser.add_argument("--force", action='store_true',
|
parser.add_argument("--force", action="store_true", help="Overwrite dump_path if it already exists.")
|
||||||
help="Overwrite dump_path if it already exists.")
|
|
||||||
|
|
||||||
parser.add_argument("--dump_path", type=str, required=True,
|
parser.add_argument(
|
||||||
help="The output directory (log, checkpoints, parameters, etc.)")
|
"--dump_path", type=str, required=True, help="The output directory (log, checkpoints, parameters, etc.)"
|
||||||
parser.add_argument("--data_file", type=str, required=True,
|
)
|
||||||
help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.")
|
parser.add_argument(
|
||||||
|
"--data_file",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--student_type", type=str, choices=["distilbert", "roberta", "gpt2"], required=True,
|
parser.add_argument(
|
||||||
help="The student type (DistilBERT, RoBERTa).")
|
"--student_type",
|
||||||
parser.add_argument("--student_config", type=str, required=True,
|
type=str,
|
||||||
help="Path to the student configuration.")
|
choices=["distilbert", "roberta", "gpt2"],
|
||||||
parser.add_argument("--student_pretrained_weights", default=None, type=str,
|
required=True,
|
||||||
help="Load student initialization checkpoint.")
|
help="The student type (DistilBERT, RoBERTa).",
|
||||||
|
)
|
||||||
|
parser.add_argument("--student_config", type=str, required=True, help="Path to the student configuration.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--student_pretrained_weights", default=None, type=str, help="Load student initialization checkpoint."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--teacher_type", choices=["bert", "roberta", "gpt2"], required=True,
|
parser.add_argument(
|
||||||
help="Teacher type (BERT, RoBERTa).")
|
"--teacher_type", choices=["bert", "roberta", "gpt2"], required=True, help="Teacher type (BERT, RoBERTa)."
|
||||||
parser.add_argument("--teacher_name", type=str, required=True,
|
)
|
||||||
help="The teacher model.")
|
parser.add_argument("--teacher_name", type=str, required=True, help="The teacher model.")
|
||||||
|
|
||||||
parser.add_argument("--temperature", default=2., type=float,
|
parser.add_argument("--temperature", default=2.0, type=float, help="Temperature for the softmax temperature.")
|
||||||
help="Temperature for the softmax temperature.")
|
parser.add_argument(
|
||||||
parser.add_argument("--alpha_ce", default=0.5, type=float,
|
"--alpha_ce", default=0.5, type=float, help="Linear weight for the distillation loss. Must be >=0."
|
||||||
help="Linear weight for the distillation loss. Must be >=0.")
|
)
|
||||||
parser.add_argument("--alpha_mlm", default=0.0, type=float,
|
parser.add_argument(
|
||||||
help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.")
|
"--alpha_mlm",
|
||||||
parser.add_argument("--alpha_clm", default=0.5, type=float,
|
default=0.0,
|
||||||
help="Linear weight for the CLM loss. Must be >=0.")
|
type=float,
|
||||||
parser.add_argument("--alpha_mse", default=0.0, type=float,
|
help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.",
|
||||||
help="Linear weight of the MSE loss. Must be >=0.")
|
)
|
||||||
parser.add_argument("--alpha_cos", default=0.0, type=float,
|
parser.add_argument("--alpha_clm", default=0.5, type=float, help="Linear weight for the CLM loss. Must be >=0.")
|
||||||
help="Linear weight of the cosine embedding loss. Must be >=0.")
|
parser.add_argument("--alpha_mse", default=0.0, type=float, help="Linear weight of the MSE loss. Must be >=0.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--alpha_cos", default=0.0, type=float, help="Linear weight of the cosine embedding loss. Must be >=0."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--mlm", action="store_true",
|
parser.add_argument(
|
||||||
help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM.")
|
"--mlm", action="store_true", help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM."
|
||||||
parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
|
)
|
||||||
help="Proportion of tokens for which we need to make a prediction.")
|
parser.add_argument(
|
||||||
parser.add_argument("--word_mask", default=0.8, type=float,
|
"--mlm_mask_prop",
|
||||||
help="Proportion of tokens to mask out.")
|
default=0.15,
|
||||||
parser.add_argument("--word_keep", default=0.1, type=float,
|
type=float,
|
||||||
help="Proportion of tokens to keep.")
|
help="Proportion of tokens for which we need to make a prediction.",
|
||||||
parser.add_argument("--word_rand", default=0.1, type=float,
|
)
|
||||||
help="Proportion of tokens to randomly replace.")
|
parser.add_argument("--word_mask", default=0.8, type=float, help="Proportion of tokens to mask out.")
|
||||||
parser.add_argument("--mlm_smoothing", default=0.7, type=float,
|
parser.add_argument("--word_keep", default=0.1, type=float, help="Proportion of tokens to keep.")
|
||||||
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).")
|
parser.add_argument("--word_rand", default=0.1, type=float, help="Proportion of tokens to randomly replace.")
|
||||||
parser.add_argument("--token_counts", type=str,
|
parser.add_argument(
|
||||||
help="The token counts in the data_file for MLM.")
|
"--mlm_smoothing",
|
||||||
|
default=0.7,
|
||||||
|
type=float,
|
||||||
|
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).",
|
||||||
|
)
|
||||||
|
parser.add_argument("--token_counts", type=str, help="The token counts in the data_file for MLM.")
|
||||||
|
|
||||||
parser.add_argument("--restrict_ce_to_mask", action='store_true',
|
parser.add_argument(
|
||||||
help="If true, compute the distilation loss only the [MLM] prediction distribution.")
|
"--restrict_ce_to_mask",
|
||||||
parser.add_argument("--freeze_pos_embs", action="store_true",
|
action="store_true",
|
||||||
help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.")
|
help="If true, compute the distilation loss only the [MLM] prediction distribution.",
|
||||||
parser.add_argument("--freeze_token_type_embds", action="store_true",
|
)
|
||||||
help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.")
|
parser.add_argument(
|
||||||
|
"--freeze_pos_embs",
|
||||||
|
action="store_true",
|
||||||
|
help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--freeze_token_type_embds",
|
||||||
|
action="store_true",
|
||||||
|
help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--n_epoch", type=int, default=3,
|
parser.add_argument("--n_epoch", type=int, default=3, help="Number of pass on the whole dataset.")
|
||||||
help="Number of pass on the whole dataset.")
|
parser.add_argument("--batch_size", type=int, default=5, help="Batch size (for each process).")
|
||||||
parser.add_argument("--batch_size", type=int, default=5,
|
parser.add_argument(
|
||||||
help="Batch size (for each process).")
|
"--group_by_size",
|
||||||
parser.add_argument("--group_by_size", action='store_false',
|
action="store_false",
|
||||||
help="If true, group sequences that have similar length into the same batch. Default is true.")
|
help="If true, group sequences that have similar length into the same batch. Default is true.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=50,
|
parser.add_argument(
|
||||||
help="Gradient accumulation for larger training batches.")
|
"--gradient_accumulation_steps",
|
||||||
parser.add_argument("--warmup_prop", default=0.05, type=float,
|
type=int,
|
||||||
help="Linear warmup proportion.")
|
default=50,
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
help="Gradient accumulation for larger training batches.",
|
||||||
help="Weight deay if we apply some.")
|
)
|
||||||
parser.add_argument("--learning_rate", default=5e-4, type=float,
|
parser.add_argument("--warmup_prop", default=0.05, type=float, help="Linear warmup proportion.")
|
||||||
help="The initial learning rate for Adam.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||||
parser.add_argument("--adam_epsilon", default=1e-6, type=float,
|
parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.")
|
||||||
help="Epsilon for Adam optimizer.")
|
parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.")
|
||||||
parser.add_argument("--max_grad_norm", default=5.0, type=float,
|
parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.")
|
||||||
help="Max gradient norm.")
|
parser.add_argument("--initializer_range", default=0.02, type=float, help="Random initialization range.")
|
||||||
parser.add_argument("--initializer_range", default=0.02, type=float,
|
|
||||||
help="Random initialization range.")
|
|
||||||
|
|
||||||
parser.add_argument('--fp16', action='store_true',
|
parser.add_argument(
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
"--fp16",
|
||||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
action="store_true",
|
||||||
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16_opt_level",
|
||||||
|
type=str,
|
||||||
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html")
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
parser.add_argument("--n_gpu", type=int, default=1,
|
)
|
||||||
help="Number of GPUs in the node.")
|
parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
|
||||||
parser.add_argument("--local_rank", type=int, default=-1,
|
parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
|
||||||
help="Distributed training - Local rank")
|
parser.add_argument("--seed", type=int, default=56, help="Random seed")
|
||||||
parser.add_argument("--seed", type=int, default=56,
|
|
||||||
help="Random seed")
|
|
||||||
|
|
||||||
parser.add_argument("--log_interval", type=int, default=500,
|
parser.add_argument("--log_interval", type=int, default=500, help="Tensorboard logging interval.")
|
||||||
help="Tensorboard logging interval.")
|
parser.add_argument("--checkpoint_interval", type=int, default=4000, help="Checkpoint interval.")
|
||||||
parser.add_argument("--checkpoint_interval", type=int, default=4000,
|
|
||||||
help="Checkpoint interval.")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
sanity_checks(args)
|
sanity_checks(args)
|
||||||
|
|
||||||
|
|
||||||
## ARGS ##
|
## ARGS ##
|
||||||
init_gpu_params(args)
|
init_gpu_params(args)
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
if args.is_master:
|
if args.is_master:
|
||||||
if os.path.exists(args.dump_path):
|
if os.path.exists(args.dump_path):
|
||||||
if not args.force:
|
if not args.force:
|
||||||
raise ValueError(f'Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it'
|
raise ValueError(
|
||||||
'Use `--force` if you want to overwrite it')
|
f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it"
|
||||||
|
"Use `--force` if you want to overwrite it"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
shutil.rmtree(args.dump_path)
|
shutil.rmtree(args.dump_path)
|
||||||
|
|
||||||
if not os.path.exists(args.dump_path):
|
if not os.path.exists(args.dump_path):
|
||||||
os.makedirs(args.dump_path)
|
os.makedirs(args.dump_path)
|
||||||
logger.info(f'Experiment will be dumped and logged in {args.dump_path}')
|
logger.info(f"Experiment will be dumped and logged in {args.dump_path}")
|
||||||
|
|
||||||
|
|
||||||
### SAVE PARAMS ###
|
### SAVE PARAMS ###
|
||||||
logger.info(f'Param: {args}')
|
logger.info(f"Param: {args}")
|
||||||
with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
|
with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
|
||||||
json.dump(vars(args), f, indent=4)
|
json.dump(vars(args), f, indent=4)
|
||||||
git_log(args.dump_path)
|
git_log(args.dump_path)
|
||||||
|
|
||||||
@@ -207,58 +241,50 @@ def main():
|
|||||||
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
|
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
|
||||||
idx = tokenizer.all_special_tokens.index(tok_symbol)
|
idx = tokenizer.all_special_tokens.index(tok_symbol)
|
||||||
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
|
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
|
||||||
logger.info(f'Special tokens {special_tok_ids}')
|
logger.info(f"Special tokens {special_tok_ids}")
|
||||||
args.special_tok_ids = special_tok_ids
|
args.special_tok_ids = special_tok_ids
|
||||||
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
|
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
|
||||||
|
|
||||||
|
|
||||||
## DATA LOADER ##
|
## DATA LOADER ##
|
||||||
logger.info(f'Loading data from {args.data_file}')
|
logger.info(f"Loading data from {args.data_file}")
|
||||||
with open(args.data_file, 'rb') as fp:
|
with open(args.data_file, "rb") as fp:
|
||||||
data = pickle.load(fp)
|
data = pickle.load(fp)
|
||||||
|
|
||||||
|
|
||||||
if args.mlm:
|
if args.mlm:
|
||||||
logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
|
logger.info(f"Loading token counts from {args.token_counts} (already pre-computed)")
|
||||||
with open(args.token_counts, 'rb') as fp:
|
with open(args.token_counts, "rb") as fp:
|
||||||
counts = pickle.load(fp)
|
counts = pickle.load(fp)
|
||||||
|
|
||||||
token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
|
token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
|
||||||
for idx in special_tok_ids.values():
|
for idx in special_tok_ids.values():
|
||||||
token_probs[idx] = 0. # do not predict special tokens
|
token_probs[idx] = 0.0 # do not predict special tokens
|
||||||
token_probs = torch.from_numpy(token_probs)
|
token_probs = torch.from_numpy(token_probs)
|
||||||
else:
|
else:
|
||||||
token_probs = None
|
token_probs = None
|
||||||
|
|
||||||
|
|
||||||
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
|
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
|
||||||
logger.info(f'Data loader created.')
|
logger.info(f"Data loader created.")
|
||||||
|
|
||||||
|
|
||||||
## STUDENT ##
|
## STUDENT ##
|
||||||
logger.info(f'Loading student config from {args.student_config}')
|
logger.info(f"Loading student config from {args.student_config}")
|
||||||
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
|
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
|
||||||
stu_architecture_config.output_hidden_states = True
|
stu_architecture_config.output_hidden_states = True
|
||||||
|
|
||||||
if args.student_pretrained_weights is not None:
|
if args.student_pretrained_weights is not None:
|
||||||
logger.info(f'Loading pretrained weights from {args.student_pretrained_weights}')
|
logger.info(f"Loading pretrained weights from {args.student_pretrained_weights}")
|
||||||
student = student_model_class.from_pretrained(args.student_pretrained_weights,
|
student = student_model_class.from_pretrained(args.student_pretrained_weights, config=stu_architecture_config)
|
||||||
config=stu_architecture_config)
|
|
||||||
else:
|
else:
|
||||||
student = student_model_class(stu_architecture_config)
|
student = student_model_class(stu_architecture_config)
|
||||||
|
|
||||||
|
|
||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
student.to(f'cuda:{args.local_rank}')
|
student.to(f"cuda:{args.local_rank}")
|
||||||
logger.info(f'Student loaded.')
|
logger.info(f"Student loaded.")
|
||||||
|
|
||||||
|
|
||||||
## TEACHER ##
|
## TEACHER ##
|
||||||
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
|
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
|
||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
teacher.to(f'cuda:{args.local_rank}')
|
teacher.to(f"cuda:{args.local_rank}")
|
||||||
logger.info(f'Teacher loaded from {args.teacher_name}.')
|
logger.info(f"Teacher loaded from {args.teacher_name}.")
|
||||||
|
|
||||||
|
|
||||||
## FREEZING ##
|
## FREEZING ##
|
||||||
if args.freeze_pos_embs:
|
if args.freeze_pos_embs:
|
||||||
@@ -266,7 +292,6 @@ def main():
|
|||||||
if args.freeze_token_type_embds:
|
if args.freeze_token_type_embds:
|
||||||
freeze_token_type_embeddings(student, args)
|
freeze_token_type_embeddings(student, args)
|
||||||
|
|
||||||
|
|
||||||
## SANITY CHECKS ##
|
## SANITY CHECKS ##
|
||||||
assert student.config.vocab_size == teacher.config.vocab_size
|
assert student.config.vocab_size == teacher.config.vocab_size
|
||||||
assert student.config.hidden_size == teacher.config.hidden_size
|
assert student.config.hidden_size == teacher.config.hidden_size
|
||||||
@@ -274,14 +299,11 @@ def main():
|
|||||||
if args.mlm:
|
if args.mlm:
|
||||||
assert token_probs.size(0) == stu_architecture_config.vocab_size
|
assert token_probs.size(0) == stu_architecture_config.vocab_size
|
||||||
|
|
||||||
|
|
||||||
## DISTILLER ##
|
## DISTILLER ##
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
distiller = Distiller(params=args,
|
distiller = Distiller(
|
||||||
dataset=train_lm_seq_dataset,
|
params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
|
||||||
token_probs=token_probs,
|
)
|
||||||
student=student,
|
|
||||||
teacher=teacher)
|
|
||||||
distiller.train()
|
distiller.train()
|
||||||
logger.info("Let's go get some drinks.")
|
logger.info("Let's go get some drinks.")
|
||||||
|
|
||||||
|
|||||||
@@ -23,9 +23,12 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s',
|
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
logging.basicConfig(
|
||||||
level = logging.INFO)
|
format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
|
||||||
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
level=logging.INFO,
|
||||||
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -35,12 +38,12 @@ def git_log(folder_path: str):
|
|||||||
"""
|
"""
|
||||||
repo = git.Repo(search_parent_directories=True)
|
repo = git.Repo(search_parent_directories=True)
|
||||||
repo_infos = {
|
repo_infos = {
|
||||||
'repo_id': str(repo),
|
"repo_id": str(repo),
|
||||||
'repo_sha': str(repo.head.object.hexsha),
|
"repo_sha": str(repo.head.object.hexsha),
|
||||||
'repo_branch': str(repo.active_branch)
|
"repo_branch": str(repo.active_branch),
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(os.path.join(folder_path, 'git_log.json'), 'w') as f:
|
with open(os.path.join(folder_path, "git_log.json"), "w") as f:
|
||||||
json.dump(repo_infos, f, indent=4)
|
json.dump(repo_infos, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
@@ -57,21 +60,21 @@ def init_gpu_params(params):
|
|||||||
|
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
|
|
||||||
logger.info('Initializing GPUs')
|
logger.info("Initializing GPUs")
|
||||||
if params.n_gpu > 1:
|
if params.n_gpu > 1:
|
||||||
assert params.local_rank != -1
|
assert params.local_rank != -1
|
||||||
|
|
||||||
params.world_size = int(os.environ['WORLD_SIZE'])
|
params.world_size = int(os.environ["WORLD_SIZE"])
|
||||||
params.n_gpu_per_node = int(os.environ['N_GPU_NODE'])
|
params.n_gpu_per_node = int(os.environ["N_GPU_NODE"])
|
||||||
params.global_rank = int(os.environ['RANK'])
|
params.global_rank = int(os.environ["RANK"])
|
||||||
|
|
||||||
# number of nodes / node ID
|
# number of nodes / node ID
|
||||||
params.n_nodes = params.world_size // params.n_gpu_per_node
|
params.n_nodes = params.world_size // params.n_gpu_per_node
|
||||||
params.node_id = params.global_rank // params.n_gpu_per_node
|
params.node_id = params.global_rank // params.n_gpu_per_node
|
||||||
params.multi_gpu = True
|
params.multi_gpu = True
|
||||||
|
|
||||||
assert params.n_nodes == int(os.environ['N_NODES'])
|
assert params.n_nodes == int(os.environ["N_NODES"])
|
||||||
assert params.node_id == int(os.environ['NODE_RANK'])
|
assert params.node_id == int(os.environ["NODE_RANK"])
|
||||||
|
|
||||||
# local job (single GPU)
|
# local job (single GPU)
|
||||||
else:
|
else:
|
||||||
@@ -114,8 +117,7 @@ def init_gpu_params(params):
|
|||||||
if params.multi_gpu:
|
if params.multi_gpu:
|
||||||
logger.info("Initializing PyTorch distributed")
|
logger.info("Initializing PyTorch distributed")
|
||||||
torch.distributed.init_process_group(
|
torch.distributed.init_process_group(
|
||||||
init_method='env://',
|
init_method="env://", backend="nccl",
|
||||||
backend='nccl',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -40,29 +40,49 @@ from tqdm import tqdm, trange
|
|||||||
|
|
||||||
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_mmimdb_labels, get_image_transforms
|
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_mmimdb_labels, get_image_transforms
|
||||||
|
|
||||||
from transformers import (WEIGHTS_NAME,
|
from transformers import (
|
||||||
BertConfig, BertModel, BertTokenizer,
|
WEIGHTS_NAME,
|
||||||
RobertaConfig, RobertaModel, RobertaTokenizer,
|
BertConfig,
|
||||||
XLMConfig, XLMModel, XLMTokenizer,
|
BertModel,
|
||||||
XLNetConfig, XLNetModel, XLNetTokenizer,
|
BertTokenizer,
|
||||||
DistilBertConfig, DistilBertModel, DistilBertTokenizer,
|
RobertaConfig,
|
||||||
AlbertConfig, AlbertModel, AlbertTokenizer,
|
RobertaModel,
|
||||||
MMBTForClassification, MMBTConfig)
|
RobertaTokenizer,
|
||||||
|
XLMConfig,
|
||||||
|
XLMModel,
|
||||||
|
XLMTokenizer,
|
||||||
|
XLNetConfig,
|
||||||
|
XLNetModel,
|
||||||
|
XLNetTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertModel,
|
||||||
|
DistilBertTokenizer,
|
||||||
|
AlbertConfig,
|
||||||
|
AlbertModel,
|
||||||
|
AlbertTokenizer,
|
||||||
|
MMBTForClassification,
|
||||||
|
MMBTConfig,
|
||||||
|
)
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig,
|
ALL_MODELS = sum(
|
||||||
RobertaConfig, DistilBertConfig)), ())
|
(
|
||||||
|
tuple(conf.pretrained_config_archive_map.keys())
|
||||||
|
for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
|
||||||
|
),
|
||||||
|
(),
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'bert': (BertConfig, BertModel, BertTokenizer),
|
"bert": (BertConfig, BertModel, BertTokenizer),
|
||||||
'xlnet': (XLNetConfig, XLNetModel, XLNetTokenizer),
|
"xlnet": (XLNetConfig, XLNetModel, XLNetTokenizer),
|
||||||
'xlm': (XLMConfig, XLMModel, XLMTokenizer),
|
"xlm": (XLMConfig, XLMModel, XLMTokenizer),
|
||||||
'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer),
|
"roberta": (RobertaConfig, RobertaModel, RobertaTokenizer),
|
||||||
'distilbert': (DistilBertConfig, DistilBertModel, DistilBertTokenizer),
|
"distilbert": (DistilBertConfig, DistilBertModel, DistilBertTokenizer),
|
||||||
'albert': (AlbertConfig, AlbertModel, AlbertTokenizer)
|
"albert": (AlbertConfig, AlbertModel, AlbertTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -81,10 +101,13 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
|||||||
|
|
||||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
sampler=train_sampler,
|
||||||
batch_size=args.train_batch_size,
|
batch_size=args.train_batch_size,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
num_workers=args.num_workers)
|
num_workers=args.num_workers,
|
||||||
|
)
|
||||||
|
|
||||||
if args.max_steps > 0:
|
if args.max_steps > 0:
|
||||||
t_total = args.max_steps
|
t_total = args.max_steps
|
||||||
@@ -93,14 +116,19 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
|||||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
|
|
||||||
# Prepare optimizer and schedule (linear warmup and decay)
|
# Prepare optimizer and schedule (linear warmup and decay)
|
||||||
no_decay = ['bias', 'LayerNorm.weight']
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
{
|
||||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
},
|
||||||
|
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||||
]
|
]
|
||||||
|
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||||
|
)
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
@@ -114,17 +142,21 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
|||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
output_device=args.local_rank,
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
find_unused_parameters=True)
|
)
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_dataset))
|
logger.info(" Num examples = %d", len(train_dataset))
|
||||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
logger.info(
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
|
args.train_batch_size
|
||||||
|
* args.gradient_accumulation_steps
|
||||||
|
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||||
|
)
|
||||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
@@ -140,11 +172,13 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
|||||||
model.train()
|
model.train()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
labels = batch[5]
|
labels = batch[5]
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {
|
||||||
'input_modal': batch[2],
|
"input_ids": batch[0],
|
||||||
'attention_mask': batch[1],
|
"input_modal": batch[2],
|
||||||
'modal_start_tokens': batch[3],
|
"attention_mask": batch[1],
|
||||||
'modal_end_tokens': batch[4]}
|
"modal_start_tokens": batch[3],
|
||||||
|
"modal_end_tokens": batch[4],
|
||||||
|
}
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
logits = outputs[0] # model outputs are always tuple in transformers (see doc)
|
logits = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||||
loss = criterion(logits, labels)
|
loss = criterion(logits, labels)
|
||||||
@@ -174,30 +208,34 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
|||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
logs = {}
|
logs = {}
|
||||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
if (
|
||||||
|
args.local_rank == -1 and args.evaluate_during_training
|
||||||
|
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||||
results = evaluate(args, model, tokenizer, criterion)
|
results = evaluate(args, model, tokenizer, criterion)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
eval_key = 'eval_{}'.format(key)
|
eval_key = "eval_{}".format(key)
|
||||||
logs[eval_key] = value
|
logs[eval_key] = value
|
||||||
|
|
||||||
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
||||||
learning_rate_scalar = scheduler.get_lr()[0]
|
learning_rate_scalar = scheduler.get_lr()[0]
|
||||||
logs['learning_rate'] = learning_rate_scalar
|
logs["learning_rate"] = learning_rate_scalar
|
||||||
logs['loss'] = loss_scalar
|
logs["loss"] = loss_scalar
|
||||||
logging_loss = tr_loss
|
logging_loss = tr_loss
|
||||||
|
|
||||||
for key, value in logs.items():
|
for key, value in logs.items():
|
||||||
tb_writer.add_scalar(key, value, global_step)
|
tb_writer.add_scalar(key, value, global_step)
|
||||||
print(json.dumps({**logs, **{'step': global_step}}))
|
print(json.dumps({**logs, **{"step": global_step}}))
|
||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
|
torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
|
||||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
@@ -209,8 +247,8 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
|||||||
|
|
||||||
if args.local_rank == -1:
|
if args.local_rank == -1:
|
||||||
results = evaluate(args, model, tokenizer, criterion)
|
results = evaluate(args, model, tokenizer, criterion)
|
||||||
if results['micro_f1'] > best_f1:
|
if results["micro_f1"] > best_f1:
|
||||||
best_f1 = results['micro_f1']
|
best_f1 = results["micro_f1"]
|
||||||
n_no_improve = 0
|
n_no_improve = 0
|
||||||
else:
|
else:
|
||||||
n_no_improve += 1
|
n_no_improve += 1
|
||||||
@@ -236,7 +274,9 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
|
|||||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||||
# Note that DistributedSampler samples randomly
|
# Note that DistributedSampler samples randomly
|
||||||
eval_sampler = SequentialSampler(eval_dataset)
|
eval_sampler = SequentialSampler(eval_dataset)
|
||||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn)
|
eval_dataloader = DataLoader(
|
||||||
|
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn
|
||||||
|
)
|
||||||
|
|
||||||
# multi-gpu eval
|
# multi-gpu eval
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
@@ -257,11 +297,13 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
labels = batch[5]
|
labels = batch[5]
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {
|
||||||
'input_modal': batch[2],
|
"input_ids": batch[0],
|
||||||
'attention_mask': batch[1],
|
"input_modal": batch[2],
|
||||||
'modal_start_tokens': batch[3],
|
"attention_mask": batch[1],
|
||||||
'modal_end_tokens': batch[4]}
|
"modal_start_tokens": batch[3],
|
||||||
|
"modal_end_tokens": batch[4],
|
||||||
|
}
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
logits = outputs[0] # model outputs are always tuple in transformers (see doc)
|
logits = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||||
tmp_eval_loss = criterion(logits, labels)
|
tmp_eval_loss = criterion(logits, labels)
|
||||||
@@ -278,7 +320,7 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
|
|||||||
result = {
|
result = {
|
||||||
"loss": eval_loss,
|
"loss": eval_loss,
|
||||||
"macro_f1": f1_score(out_label_ids, preds, average="macro"),
|
"macro_f1": f1_score(out_label_ids, preds, average="macro"),
|
||||||
"micro_f1": f1_score(out_label_ids, preds, average="micro")
|
"micro_f1": f1_score(out_label_ids, preds, average="micro"),
|
||||||
}
|
}
|
||||||
|
|
||||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||||
@@ -303,94 +345,147 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="The input data dir. Should contain the .jsonl files for MMIMDB.")
|
"--data_dir",
|
||||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
default=None,
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
type=str,
|
||||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
required=True,
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
help="The input data dir. Should contain the .jsonl files for MMIMDB.",
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
)
|
||||||
help="The output directory where the model predictions and checkpoints will be written.")
|
parser.add_argument(
|
||||||
|
"--model_type",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
|
)
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument(
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
)
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
parser.add_argument(
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
"--tokenizer_name",
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
default="",
|
||||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
type=str,
|
||||||
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_dir",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_seq_length",
|
||||||
|
default=128,
|
||||||
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded.")
|
"than this will be truncated, sequences shorter will be padded.",
|
||||||
parser.add_argument("--num_image_embeds", default=1, type=int,
|
)
|
||||||
help="Number of Image Embeddings from the Image Encoder")
|
parser.add_argument(
|
||||||
parser.add_argument("--do_train", action='store_true',
|
"--num_image_embeds", default=1, type=int, help="Number of Image Embeddings from the Image Encoder"
|
||||||
help="Whether to run training.")
|
)
|
||||||
parser.add_argument("--do_eval", action='store_true',
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
help="Whether to run eval on the dev set.")
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
parser.add_argument(
|
||||||
help="Rul evaluation during training at each logging step.")
|
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||||
parser.add_argument("--do_lower_case", action='store_true',
|
)
|
||||||
help="Set this flag if you are using an uncased model.")
|
parser.add_argument(
|
||||||
|
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||||
help="Batch size per GPU/CPU for training.")
|
parser.add_argument(
|
||||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||||
help="Batch size per GPU/CPU for evaluation.")
|
)
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
parser.add_argument(
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
"--gradient_accumulation_steps",
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
type=int,
|
||||||
help="The initial learning rate for Adam.")
|
default=1,
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
help="Weight deay if we apply some.")
|
)
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
help="Epsilon for Adam optimizer.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
help="Max gradient norm.")
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument(
|
||||||
help="Total number of training epochs to perform.")
|
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||||
parser.add_argument("--patience", default=5, type=int,
|
)
|
||||||
help="Patience for Early Stopping.")
|
parser.add_argument("--patience", default=5, type=int, help="Patience for Early Stopping.")
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
parser.add_argument(
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
"--max_steps",
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
default=-1,
|
||||||
help="Linear warmup over warmup_steps.")
|
type=int,
|
||||||
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
|
||||||
parser.add_argument('--logging_steps', type=int, default=50,
|
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||||
help="Log every X updates steps.")
|
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||||
parser.add_argument('--save_steps', type=int, default=50,
|
parser.add_argument(
|
||||||
help="Save checkpoint every X updates steps.")
|
"--eval_all_checkpoints",
|
||||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
action="store_true",
|
||||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
)
|
||||||
help="Avoid using CUDA when available")
|
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||||
parser.add_argument('--num_workers', type=int, default=8,
|
parser.add_argument("--num_workers", type=int, default=8, help="number of worker threads for dataloading")
|
||||||
help="number of worker threads for dataloading")
|
parser.add_argument(
|
||||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||||
help="Overwrite the content of the output directory")
|
)
|
||||||
parser.add_argument('--overwrite_cache', action='store_true',
|
parser.add_argument(
|
||||||
help="Overwrite the cached training and evaluation sets")
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
parser.add_argument('--seed', type=int, default=42,
|
)
|
||||||
help="random seed for initialization")
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
|
||||||
parser.add_argument('--fp16', action='store_true',
|
parser.add_argument(
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
"--fp16",
|
||||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
action="store_true",
|
||||||
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16_opt_level",
|
||||||
|
type=str,
|
||||||
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html")
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
parser.add_argument("--local_rank", type=int, default=-1,
|
)
|
||||||
help="For distributed training: local_rank")
|
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
if (
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
os.path.exists(args.output_dir)
|
||||||
|
and os.listdir(args.output_dir)
|
||||||
|
and args.do_train
|
||||||
|
and not args.overwrite_output_dir
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||||
|
args.output_dir
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Setup distant debugging if needed
|
# Setup distant debugging if needed
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
@@ -402,17 +497,25 @@ def main():
|
|||||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||||
torch.cuda.set_device(args.local_rank)
|
torch.cuda.set_device(args.local_rank)
|
||||||
device = torch.device("cuda", args.local_rank)
|
device = torch.device("cuda", args.local_rank)
|
||||||
torch.distributed.init_process_group(backend='nccl')
|
torch.distributed.init_process_group(backend="nccl")
|
||||||
args.n_gpu = 1
|
args.n_gpu = 1
|
||||||
|
|
||||||
args.device = device
|
args.device = device
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||||
|
args.local_rank,
|
||||||
|
device,
|
||||||
|
args.n_gpu,
|
||||||
|
bool(args.local_rank != -1),
|
||||||
|
args.fp16,
|
||||||
|
)
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
@@ -426,13 +529,17 @@ def main():
|
|||||||
num_labels = len(labels)
|
num_labels = len(labels)
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
transformer_config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
transformer_config = config_class.from_pretrained(
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
args.config_name if args.config_name else args.model_name_or_path
|
||||||
|
)
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
do_lower_case=args.do_lower_case,
|
do_lower_case=args.do_lower_case,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
transformer = model_class.from_pretrained(args.model_name_or_path,
|
)
|
||||||
config=transformer_config,
|
transformer = model_class.from_pretrained(
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir if args.cache_dir else None
|
||||||
|
)
|
||||||
img_encoder = ImageEncoder(args)
|
img_encoder = ImageEncoder(args)
|
||||||
config = MMBTConfig(transformer_config, num_labels=num_labels)
|
config = MMBTConfig(transformer_config, num_labels=num_labels)
|
||||||
model = MMBTForClassification(config, transformer, img_encoder)
|
model = MMBTForClassification(config, transformer, img_encoder)
|
||||||
@@ -449,12 +556,13 @@ def main():
|
|||||||
train_dataset = load_examples(args, tokenizer, evaluate=False)
|
train_dataset = load_examples(args, tokenizer, evaluate=False)
|
||||||
label_frequences = train_dataset.get_label_frequencies()
|
label_frequences = train_dataset.get_label_frequencies()
|
||||||
label_frequences = [label_frequences[l] for l in labels]
|
label_frequences = [label_frequences[l] for l in labels]
|
||||||
label_weights = (torch.tensor(label_frequences, device=args.device, dtype=torch.float) / len(train_dataset)) ** -1
|
label_weights = (
|
||||||
|
torch.tensor(label_frequences, device=args.device, dtype=torch.float) / len(train_dataset)
|
||||||
|
) ** -1
|
||||||
criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights)
|
criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights)
|
||||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer, criterion)
|
global_step, tr_loss = train(args, train_dataset, model, tokenizer, criterion)
|
||||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
|
|
||||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
@@ -464,12 +572,14 @@ def main():
|
|||||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
|
torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = MMBTForClassification(config, transformer, img_encoder)
|
model = MMBTForClassification(config, transformer, img_encoder)
|
||||||
@@ -477,24 +587,25 @@ def main():
|
|||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
checkpoints = list(
|
||||||
|
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||||
|
)
|
||||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||||
model = MMBTForClassification(config, transformer, img_encoder)
|
model = MMBTForClassification(config, transformer, img_encoder)
|
||||||
model.load_state_dict(torch.load(checkpoint))
|
model.load_state_dict(torch.load(checkpoint))
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result = evaluate(args, model, tokenizer, criterion, prefix=prefix)
|
result = evaluate(args, model, tokenizer, criterion, prefix=prefix)
|
||||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -25,17 +25,7 @@ import torchvision
|
|||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
POOLING_BREAKDOWN = {
|
POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}
|
||||||
1: (1, 1),
|
|
||||||
2: (2, 1),
|
|
||||||
3: (3, 1),
|
|
||||||
4: (2, 2),
|
|
||||||
5: (5, 1),
|
|
||||||
6: (3, 2),
|
|
||||||
7: (7, 1),
|
|
||||||
8: (4, 2),
|
|
||||||
9: (3, 3)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ImageEncoder(nn.Module):
|
class ImageEncoder(nn.Module):
|
||||||
@@ -54,7 +44,6 @@ class ImageEncoder(nn.Module):
|
|||||||
return out # BxNx2048
|
return out # BxNx2048
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class JsonlDataset(Dataset):
|
class JsonlDataset(Dataset):
|
||||||
def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
|
def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
|
||||||
self.data = [json.loads(l) for l in open(data_path)]
|
self.data = [json.loads(l) for l in open(data_path)]
|
||||||
@@ -80,8 +69,13 @@ class JsonlDataset(Dataset):
|
|||||||
image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
|
image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
|
||||||
image = self.transforms(image)
|
image = self.transforms(image)
|
||||||
|
|
||||||
return {"image_start_token": start_token, "image_end_token": end_token,
|
return {
|
||||||
"sentence": sentence, "image": image, "label": label}
|
"image_start_token": start_token,
|
||||||
|
"image_end_token": end_token,
|
||||||
|
"sentence": sentence,
|
||||||
|
"image": image,
|
||||||
|
"label": label,
|
||||||
|
}
|
||||||
|
|
||||||
def get_label_frequencies(self):
|
def get_label_frequencies(self):
|
||||||
label_freqs = Counter()
|
label_freqs = Counter()
|
||||||
@@ -110,10 +104,31 @@ def collate_fn(batch):
|
|||||||
|
|
||||||
|
|
||||||
def get_mmimdb_labels():
|
def get_mmimdb_labels():
|
||||||
return ['Crime', 'Drama', 'Thriller', 'Action', 'Comedy', 'Romance',
|
return [
|
||||||
'Documentary', 'Short', 'Mystery', 'History', 'Family', 'Adventure',
|
"Crime",
|
||||||
'Fantasy', 'Sci-Fi', 'Western', 'Horror', 'Sport', 'War', 'Music',
|
"Drama",
|
||||||
'Musical', 'Animation', 'Biography', 'Film-Noir']
|
"Thriller",
|
||||||
|
"Action",
|
||||||
|
"Comedy",
|
||||||
|
"Romance",
|
||||||
|
"Documentary",
|
||||||
|
"Short",
|
||||||
|
"Mystery",
|
||||||
|
"History",
|
||||||
|
"Family",
|
||||||
|
"Adventure",
|
||||||
|
"Fantasy",
|
||||||
|
"Sci-Fi",
|
||||||
|
"Western",
|
||||||
|
"Horror",
|
||||||
|
"Sport",
|
||||||
|
"War",
|
||||||
|
"Music",
|
||||||
|
"Musical",
|
||||||
|
"Animation",
|
||||||
|
"Biography",
|
||||||
|
"Film-Noir",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_image_transforms():
|
def get_image_transforms():
|
||||||
@@ -122,9 +137,6 @@ def get_image_transforms():
|
|||||||
transforms.Resize(256),
|
transforms.Resize(256),
|
||||||
transforms.CenterCrop(224),
|
transforms.CenterCrop(224),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(
|
transforms.Normalize(mean=[0.46777044, 0.44531429, 0.40661017], std=[0.12221994, 0.12145835, 0.14380469],),
|
||||||
mean=[0.46777044, 0.44531429, 0.40661017],
|
|
||||||
std=[0.12221994, 0.12145835, 0.14380469],
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class ClassificationHead(torch.nn.Module):
|
class ClassificationHead(torch.nn.Module):
|
||||||
"""Classification Head for transformer encoders"""
|
"""Classification Head for transformer encoders"""
|
||||||
|
|
||||||
|
|||||||
@@ -46,13 +46,13 @@ SMALL_CONST = 1e-15
|
|||||||
BIG_CONST = 1e10
|
BIG_CONST = 1e10
|
||||||
|
|
||||||
BAG_OF_WORDS_ARCHIVE_MAP = {
|
BAG_OF_WORDS_ARCHIVE_MAP = {
|
||||||
'legal': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
|
"legal": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
|
||||||
'military': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
|
"military": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
|
||||||
'politics': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
|
"politics": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
|
||||||
'religion': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
|
"religion": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
|
||||||
'science': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
|
"science": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
|
||||||
'space': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
|
"space": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
|
||||||
'technology': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
|
"technology": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
|
||||||
}
|
}
|
||||||
|
|
||||||
DISCRIMINATOR_MODELS_PARAMS = {
|
DISCRIMINATOR_MODELS_PARAMS = {
|
||||||
@@ -75,10 +75,10 @@ DISCRIMINATOR_MODELS_PARAMS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def to_var(x, requires_grad=False, volatile=False, device='cuda'):
|
def to_var(x, requires_grad=False, volatile=False, device="cuda"):
|
||||||
if torch.cuda.is_available() and device == 'cuda':
|
if torch.cuda.is_available() and device == "cuda":
|
||||||
x = x.cuda()
|
x = x.cuda()
|
||||||
elif device != 'cuda':
|
elif device != "cuda":
|
||||||
x = x.to(device)
|
x = x.to(device)
|
||||||
return Variable(x, requires_grad=requires_grad, volatile=volatile)
|
return Variable(x, requires_grad=requires_grad, volatile=volatile)
|
||||||
|
|
||||||
@@ -95,11 +95,8 @@ def top_k_filter(logits, k, probs=False):
|
|||||||
values = torch.topk(logits, k)[0]
|
values = torch.topk(logits, k)[0]
|
||||||
batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
|
batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
|
||||||
if probs:
|
if probs:
|
||||||
return torch.where(logits < batch_mins,
|
return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits)
|
||||||
torch.ones_like(logits) * 0.0, logits)
|
return torch.where(logits < batch_mins, torch.ones_like(logits) * -BIG_CONST, logits)
|
||||||
return torch.where(logits < batch_mins,
|
|
||||||
torch.ones_like(logits) * -BIG_CONST,
|
|
||||||
logits)
|
|
||||||
|
|
||||||
|
|
||||||
def perturb_past(
|
def perturb_past(
|
||||||
@@ -121,23 +118,16 @@ def perturb_past(
|
|||||||
decay=False,
|
decay=False,
|
||||||
gamma=1.5,
|
gamma=1.5,
|
||||||
kl_scale=0.01,
|
kl_scale=0.01,
|
||||||
device='cuda',
|
device="cuda",
|
||||||
):
|
):
|
||||||
# Generate inital perturbed past
|
# Generate inital perturbed past
|
||||||
grad_accumulator = [
|
grad_accumulator = [(np.zeros(p.shape).astype("float32")) for p in past]
|
||||||
(np.zeros(p.shape).astype("float32"))
|
|
||||||
for p in past
|
|
||||||
]
|
|
||||||
|
|
||||||
if accumulated_hidden is None:
|
if accumulated_hidden is None:
|
||||||
accumulated_hidden = 0
|
accumulated_hidden = 0
|
||||||
|
|
||||||
if decay:
|
if decay:
|
||||||
decay_mask = torch.arange(
|
decay_mask = torch.arange(0.0, 1.0 + SMALL_CONST, 1.0 / (window_length))[1:]
|
||||||
0.,
|
|
||||||
1.0 + SMALL_CONST,
|
|
||||||
1.0 / (window_length)
|
|
||||||
)[1:]
|
|
||||||
else:
|
else:
|
||||||
decay_mask = 1.0
|
decay_mask = 1.0
|
||||||
|
|
||||||
@@ -146,26 +136,17 @@ def perturb_past(
|
|||||||
_, _, _, curr_length, _ = past[0].shape
|
_, _, _, curr_length, _ = past[0].shape
|
||||||
|
|
||||||
if curr_length > window_length and window_length > 0:
|
if curr_length > window_length and window_length > 0:
|
||||||
ones_key_val_shape = (
|
ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])
|
||||||
tuple(past[0].shape[:-2])
|
|
||||||
+ tuple([window_length])
|
|
||||||
+ tuple(past[0].shape[-1:])
|
|
||||||
)
|
|
||||||
|
|
||||||
zeros_key_val_shape = (
|
zeros_key_val_shape = (
|
||||||
tuple(past[0].shape[:-2])
|
tuple(past[0].shape[:-2]) + tuple([curr_length - window_length]) + tuple(past[0].shape[-1:])
|
||||||
+ tuple([curr_length - window_length])
|
|
||||||
+ tuple(past[0].shape[-1:])
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ones_mask = torch.ones(ones_key_val_shape)
|
ones_mask = torch.ones(ones_key_val_shape)
|
||||||
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
|
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
|
||||||
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
|
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
|
||||||
|
|
||||||
window_mask = torch.cat(
|
window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2).to(device)
|
||||||
(ones_mask, torch.zeros(zeros_key_val_shape)),
|
|
||||||
dim=-2
|
|
||||||
).to(device)
|
|
||||||
else:
|
else:
|
||||||
window_mask = torch.ones_like(past[0]).to(device)
|
window_mask = torch.ones_like(past[0]).to(device)
|
||||||
|
|
||||||
@@ -175,8 +156,7 @@ def perturb_past(
|
|||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
print("Iteration ", i + 1)
|
print("Iteration ", i + 1)
|
||||||
curr_perturbation = [
|
curr_perturbation = [
|
||||||
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
|
to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator
|
||||||
for p_ in grad_accumulator
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Compute hidden using perturbed past
|
# Compute hidden using perturbed past
|
||||||
@@ -184,10 +164,7 @@ def perturb_past(
|
|||||||
_, _, _, curr_length, _ = curr_perturbation[0].shape
|
_, _, _, curr_length, _ = curr_perturbation[0].shape
|
||||||
all_logits, _, all_hidden = model(last, past=perturbed_past)
|
all_logits, _, all_hidden = model(last, past=perturbed_past)
|
||||||
hidden = all_hidden[-1]
|
hidden = all_hidden[-1]
|
||||||
new_accumulated_hidden = accumulated_hidden + torch.sum(
|
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
|
||||||
hidden,
|
|
||||||
dim=1
|
|
||||||
).detach()
|
|
||||||
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
||||||
logits = all_logits[:, -1, :]
|
logits = all_logits[:, -1, :]
|
||||||
probs = F.softmax(logits, dim=-1)
|
probs = F.softmax(logits, dim=-1)
|
||||||
@@ -210,20 +187,13 @@ def perturb_past(
|
|||||||
wte = model.resize_token_embeddings()
|
wte = model.resize_token_embeddings()
|
||||||
for _ in range(horizon_length):
|
for _ in range(horizon_length):
|
||||||
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
||||||
_, curr_unpert_past, curr_all_hidden = model(
|
_, curr_unpert_past, curr_all_hidden = model(past=curr_unpert_past, inputs_embeds=inputs_embeds)
|
||||||
past=curr_unpert_past,
|
|
||||||
inputs_embeds=inputs_embeds
|
|
||||||
)
|
|
||||||
curr_hidden = curr_all_hidden[-1]
|
curr_hidden = curr_all_hidden[-1]
|
||||||
new_accumulated_hidden = new_accumulated_hidden + torch.sum(
|
new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)
|
||||||
curr_hidden, dim=1)
|
|
||||||
|
|
||||||
prediction = classifier(new_accumulated_hidden /
|
prediction = classifier(new_accumulated_hidden / (curr_length + 1 + horizon_length))
|
||||||
(curr_length + 1 + horizon_length))
|
|
||||||
|
|
||||||
label = torch.tensor(prediction.shape[0] * [class_label],
|
label = torch.tensor(prediction.shape[0] * [class_label], device=device, dtype=torch.long)
|
||||||
device=device,
|
|
||||||
dtype=torch.long)
|
|
||||||
discrim_loss = ce_loss(prediction, label)
|
discrim_loss = ce_loss(prediction, label)
|
||||||
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
||||||
loss += discrim_loss
|
loss += discrim_loss
|
||||||
@@ -232,21 +202,15 @@ def perturb_past(
|
|||||||
kl_loss = 0.0
|
kl_loss = 0.0
|
||||||
if kl_scale > 0.0:
|
if kl_scale > 0.0:
|
||||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||||
unpert_probs = (
|
unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
|
||||||
unpert_probs + SMALL_CONST *
|
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
|
||||||
(unpert_probs <= SMALL_CONST).float().to(device).detach()
|
|
||||||
)
|
|
||||||
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
|
|
||||||
device).detach()
|
|
||||||
corrected_probs = probs + correction.detach()
|
corrected_probs = probs + correction.detach()
|
||||||
kl_loss = kl_scale * (
|
kl_loss = kl_scale * ((corrected_probs * (corrected_probs / unpert_probs).log()).sum())
|
||||||
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
|
print(" kl_loss", kl_loss.data.cpu().numpy())
|
||||||
)
|
|
||||||
print(' kl_loss', kl_loss.data.cpu().numpy())
|
|
||||||
loss += kl_loss
|
loss += kl_loss
|
||||||
|
|
||||||
loss_per_iter.append(loss.data.cpu().numpy())
|
loss_per_iter.append(loss.data.cpu().numpy())
|
||||||
print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
|
print(" pplm_loss", (loss - kl_loss).data.cpu().numpy())
|
||||||
|
|
||||||
# compute gradients
|
# compute gradients
|
||||||
loss.backward()
|
loss.backward()
|
||||||
@@ -259,15 +223,12 @@ def perturb_past(
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
grad_norms = [
|
grad_norms = [
|
||||||
(torch.norm(p_.grad * window_mask) + SMALL_CONST)
|
(torch.norm(p_.grad * window_mask) + SMALL_CONST) for index, p_ in enumerate(curr_perturbation)
|
||||||
for index, p_ in enumerate(curr_perturbation)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# normalize gradients
|
# normalize gradients
|
||||||
grad = [
|
grad = [
|
||||||
-stepsize *
|
-stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
|
||||||
(p_.grad * window_mask / grad_norms[
|
|
||||||
index] ** gamma).data.cpu().numpy()
|
|
||||||
for index, p_ in enumerate(curr_perturbation)
|
for index, p_ in enumerate(curr_perturbation)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -285,36 +246,27 @@ def perturb_past(
|
|||||||
past = new_past
|
past = new_past
|
||||||
|
|
||||||
# apply the accumulated perturbations to the past
|
# apply the accumulated perturbations to the past
|
||||||
grad_accumulator = [
|
grad_accumulator = [to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator]
|
||||||
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
|
|
||||||
for p_ in grad_accumulator
|
|
||||||
]
|
|
||||||
pert_past = list(map(add, past, grad_accumulator))
|
pert_past = list(map(add, past, grad_accumulator))
|
||||||
|
|
||||||
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
||||||
|
|
||||||
|
|
||||||
def get_classifier(
|
def get_classifier(
|
||||||
name: Optional[str], class_label: Union[str, int],
|
name: Optional[str], class_label: Union[str, int], device: str
|
||||||
device: str
|
|
||||||
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
|
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
|
||||||
if name is None:
|
if name is None:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
params = DISCRIMINATOR_MODELS_PARAMS[name]
|
params = DISCRIMINATOR_MODELS_PARAMS[name]
|
||||||
classifier = ClassificationHead(
|
classifier = ClassificationHead(class_size=params["class_size"], embed_size=params["embed_size"]).to(device)
|
||||||
class_size=params['class_size'],
|
|
||||||
embed_size=params['embed_size']
|
|
||||||
).to(device)
|
|
||||||
if "url" in params:
|
if "url" in params:
|
||||||
resolved_archive_file = cached_path(params["url"])
|
resolved_archive_file = cached_path(params["url"])
|
||||||
elif "path" in params:
|
elif "path" in params:
|
||||||
resolved_archive_file = params["path"]
|
resolved_archive_file = params["path"]
|
||||||
else:
|
else:
|
||||||
raise ValueError("Either url or path have to be specified "
|
raise ValueError("Either url or path have to be specified " "in the discriminator model parameters")
|
||||||
"in the discriminator model parameters")
|
classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
|
||||||
classifier.load_state_dict(
|
|
||||||
torch.load(resolved_archive_file, map_location=device))
|
|
||||||
classifier.eval()
|
classifier.eval()
|
||||||
|
|
||||||
if isinstance(class_label, str):
|
if isinstance(class_label, str):
|
||||||
@@ -341,8 +293,7 @@ def get_classifier(
|
|||||||
return classifier, label_id
|
return classifier, label_id
|
||||||
|
|
||||||
|
|
||||||
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
|
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> List[List[List[int]]]:
|
||||||
List[List[List[int]]]:
|
|
||||||
bow_indices = []
|
bow_indices = []
|
||||||
for id_or_path in bag_of_words_ids_or_paths:
|
for id_or_path in bag_of_words_ids_or_paths:
|
||||||
if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
|
if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
|
||||||
@@ -351,13 +302,11 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) ->
|
|||||||
filepath = id_or_path
|
filepath = id_or_path
|
||||||
with open(filepath, "r") as f:
|
with open(filepath, "r") as f:
|
||||||
words = f.read().strip().split("\n")
|
words = f.read().strip().split("\n")
|
||||||
bow_indices.append(
|
bow_indices.append([tokenizer.encode(word.strip(), add_prefix_space=True) for word in words])
|
||||||
[tokenizer.encode(word.strip(), add_prefix_space=True) for word in
|
|
||||||
words])
|
|
||||||
return bow_indices
|
return bow_indices
|
||||||
|
|
||||||
|
|
||||||
def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
|
def build_bows_one_hot_vectors(bow_indices, tokenizer, device="cuda"):
|
||||||
if bow_indices is None:
|
if bow_indices is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -396,16 +345,11 @@ def full_text_generation(
|
|||||||
kl_scale=0.01,
|
kl_scale=0.01,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
classifier, class_id = get_classifier(
|
classifier, class_id = get_classifier(discrim, class_label, device)
|
||||||
discrim,
|
|
||||||
class_label,
|
|
||||||
device
|
|
||||||
)
|
|
||||||
|
|
||||||
bow_indices = []
|
bow_indices = []
|
||||||
if bag_of_words:
|
if bag_of_words:
|
||||||
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
|
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
|
||||||
tokenizer)
|
|
||||||
|
|
||||||
if bag_of_words and classifier:
|
if bag_of_words and classifier:
|
||||||
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
|
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
|
||||||
@@ -423,15 +367,9 @@ def full_text_generation(
|
|||||||
raise Exception("Specify either a bag of words or a discriminator")
|
raise Exception("Specify either a bag of words or a discriminator")
|
||||||
|
|
||||||
unpert_gen_tok_text, _, _ = generate_text_pplm(
|
unpert_gen_tok_text, _, _ = generate_text_pplm(
|
||||||
model=model,
|
model=model, tokenizer=tokenizer, context=context, device=device, length=length, sample=sample, perturb=False
|
||||||
tokenizer=tokenizer,
|
|
||||||
context=context,
|
|
||||||
device=device,
|
|
||||||
length=length,
|
|
||||||
sample=sample,
|
|
||||||
perturb=False
|
|
||||||
)
|
)
|
||||||
if device == 'cuda':
|
if device == "cuda":
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
pert_gen_tok_texts = []
|
pert_gen_tok_texts = []
|
||||||
@@ -468,7 +406,7 @@ def full_text_generation(
|
|||||||
discrim_losses.append(discrim_loss.data.cpu().numpy())
|
discrim_losses.append(discrim_loss.data.cpu().numpy())
|
||||||
losses_in_time.append(loss_in_time)
|
losses_in_time.append(loss_in_time)
|
||||||
|
|
||||||
if device == 'cuda':
|
if device == "cuda":
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
||||||
@@ -507,8 +445,7 @@ def generate_text_pplm(
|
|||||||
output_so_far = context_t
|
output_so_far = context_t
|
||||||
|
|
||||||
# collect one hot vectors for bags of words
|
# collect one hot vectors for bags of words
|
||||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
|
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, device)
|
||||||
device)
|
|
||||||
|
|
||||||
grad_norms = None
|
grad_norms = None
|
||||||
last = None
|
last = None
|
||||||
@@ -575,13 +512,9 @@ def generate_text_pplm(
|
|||||||
if classifier is not None:
|
if classifier is not None:
|
||||||
ce_loss = torch.nn.CrossEntropyLoss()
|
ce_loss = torch.nn.CrossEntropyLoss()
|
||||||
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
||||||
label = torch.tensor([class_label], device=device,
|
label = torch.tensor([class_label], device=device, dtype=torch.long)
|
||||||
dtype=torch.long)
|
|
||||||
unpert_discrim_loss = ce_loss(prediction, label)
|
unpert_discrim_loss = ce_loss(prediction, label)
|
||||||
print(
|
print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
|
||||||
"unperturbed discrim loss",
|
|
||||||
unpert_discrim_loss.data.cpu().numpy()
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
unpert_discrim_loss = 0
|
unpert_discrim_loss = 0
|
||||||
|
|
||||||
@@ -590,10 +523,8 @@ def generate_text_pplm(
|
|||||||
|
|
||||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||||
|
|
||||||
pert_probs = ((pert_probs ** gm_scale) * (
|
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
|
||||||
unpert_probs ** (1 - gm_scale))) # + SMALL_CONST
|
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
|
||||||
pert_probs = top_k_filter(pert_probs, k=top_k,
|
|
||||||
probs=True) # + SMALL_CONST
|
|
||||||
|
|
||||||
# rescale
|
# rescale
|
||||||
if torch.sum(pert_probs) <= 1:
|
if torch.sum(pert_probs) <= 1:
|
||||||
@@ -611,10 +542,7 @@ def generate_text_pplm(
|
|||||||
_, last = torch.topk(pert_probs, k=1, dim=-1)
|
_, last = torch.topk(pert_probs, k=1, dim=-1)
|
||||||
|
|
||||||
# update context/output_so_far appending the new token
|
# update context/output_so_far appending the new token
|
||||||
output_so_far = (
|
output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1)
|
||||||
last if output_so_far is None
|
|
||||||
else torch.cat((output_so_far, last), dim=1)
|
|
||||||
)
|
|
||||||
|
|
||||||
print(tokenizer.decode(output_so_far.tolist()[0]))
|
print(tokenizer.decode(output_so_far.tolist()[0]))
|
||||||
|
|
||||||
@@ -623,16 +551,14 @@ def generate_text_pplm(
|
|||||||
|
|
||||||
def set_generic_model_params(discrim_weights, discrim_meta):
|
def set_generic_model_params(discrim_weights, discrim_meta):
|
||||||
if discrim_weights is None:
|
if discrim_weights is None:
|
||||||
raise ValueError('When using a generic discriminator, '
|
raise ValueError("When using a generic discriminator, " "discrim_weights need to be specified")
|
||||||
'discrim_weights need to be specified')
|
|
||||||
if discrim_meta is None:
|
if discrim_meta is None:
|
||||||
raise ValueError('When using a generic discriminator, '
|
raise ValueError("When using a generic discriminator, " "discrim_meta need to be specified")
|
||||||
'discrim_meta need to be specified')
|
|
||||||
|
|
||||||
with open(discrim_meta, 'r') as discrim_meta_file:
|
with open(discrim_meta, "r") as discrim_meta_file:
|
||||||
meta = json.load(discrim_meta_file)
|
meta = json.load(discrim_meta_file)
|
||||||
meta['path'] = discrim_weights
|
meta["path"] = discrim_weights
|
||||||
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
|
DISCRIMINATOR_MODELS_PARAMS["generic"] = meta
|
||||||
|
|
||||||
|
|
||||||
def run_pplm_example(
|
def run_pplm_example(
|
||||||
@@ -660,7 +586,7 @@ def run_pplm_example(
|
|||||||
kl_scale=0.01,
|
kl_scale=0.01,
|
||||||
seed=0,
|
seed=0,
|
||||||
no_cuda=False,
|
no_cuda=False,
|
||||||
colorama=False
|
colorama=False,
|
||||||
):
|
):
|
||||||
# set Random seed
|
# set Random seed
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
@@ -669,21 +595,15 @@ def run_pplm_example(
|
|||||||
# set the device
|
# set the device
|
||||||
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||||
|
|
||||||
if discrim == 'generic':
|
if discrim == "generic":
|
||||||
set_generic_model_params(discrim_weights, discrim_meta)
|
set_generic_model_params(discrim_weights, discrim_meta)
|
||||||
|
|
||||||
if discrim is not None:
|
if discrim is not None:
|
||||||
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
|
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
|
||||||
"pretrained_model"
|
print("discrim = {}, pretrained_model set " "to discriminator's = {}".format(discrim, pretrained_model))
|
||||||
]
|
|
||||||
print("discrim = {}, pretrained_model set "
|
|
||||||
"to discriminator's = {}".format(discrim, pretrained_model))
|
|
||||||
|
|
||||||
# load pretrained model
|
# load pretrained model
|
||||||
model = GPT2LMHeadModel.from_pretrained(
|
model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
|
||||||
pretrained_model,
|
|
||||||
output_hidden_states=True
|
|
||||||
)
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@@ -696,9 +616,7 @@ def run_pplm_example(
|
|||||||
|
|
||||||
# figure out conditioning text
|
# figure out conditioning text
|
||||||
if uncond:
|
if uncond:
|
||||||
tokenized_cond_text = tokenizer.encode(
|
tokenized_cond_text = tokenizer.encode([tokenizer.bos_token])
|
||||||
[tokenizer.bos_token]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raw_text = cond_text
|
raw_text = cond_text
|
||||||
while not raw_text:
|
while not raw_text:
|
||||||
@@ -750,8 +668,7 @@ def run_pplm_example(
|
|||||||
|
|
||||||
bow_word_ids = set()
|
bow_word_ids = set()
|
||||||
if bag_of_words and colorama:
|
if bag_of_words and colorama:
|
||||||
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
|
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
|
||||||
tokenizer)
|
|
||||||
for single_bow_list in bow_indices:
|
for single_bow_list in bow_indices:
|
||||||
# filtering all words in the list composed of more than 1 token
|
# filtering all words in the list composed of more than 1 token
|
||||||
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
|
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
|
||||||
@@ -765,13 +682,11 @@ def run_pplm_example(
|
|||||||
if colorama:
|
if colorama:
|
||||||
import colorama
|
import colorama
|
||||||
|
|
||||||
pert_gen_text = ''
|
pert_gen_text = ""
|
||||||
for word_id in pert_gen_tok_text.tolist()[0]:
|
for word_id in pert_gen_tok_text.tolist()[0]:
|
||||||
if word_id in bow_word_ids:
|
if word_id in bow_word_ids:
|
||||||
pert_gen_text += '{}{}{}'.format(
|
pert_gen_text += "{}{}{}".format(
|
||||||
colorama.Fore.RED,
|
colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL
|
||||||
tokenizer.decode([word_id]),
|
|
||||||
colorama.Style.RESET_ALL
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pert_gen_text += tokenizer.decode([word_id])
|
pert_gen_text += tokenizer.decode([word_id])
|
||||||
@@ -785,14 +700,12 @@ def run_pplm_example(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# keep the prefix, perturbed seq, original seq for each index
|
# keep the prefix, perturbed seq, original seq for each index
|
||||||
generated_texts.append(
|
generated_texts.append((tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text))
|
||||||
(tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pretrained_model",
|
"--pretrained_model",
|
||||||
@@ -801,19 +714,10 @@ if __name__ == '__main__':
|
|||||||
default="gpt2-medium",
|
default="gpt2-medium",
|
||||||
help="pretrained model name or path to local checkpoint",
|
help="pretrained model name or path to local checkpoint",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--cond_text", type=str, default="The lake", help="Prefix texts to condition on")
|
||||||
|
parser.add_argument("--uncond", action="store_true", help="Generate from end-of-text as prefix")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cond_text", type=str, default="The lake",
|
"--num_samples", type=int, default=1, help="Number of samples to generate from the modified latents",
|
||||||
help="Prefix texts to condition on"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--uncond", action="store_true",
|
|
||||||
help="Generate from end-of-text as prefix"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_samples",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Number of samples to generate from the modified latents",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bag_of_words",
|
"--bag_of_words",
|
||||||
@@ -832,48 +736,36 @@ if __name__ == '__main__':
|
|||||||
choices=("clickbait", "sentiment", "toxicity", "generic"),
|
choices=("clickbait", "sentiment", "toxicity", "generic"),
|
||||||
help="Discriminator to use",
|
help="Discriminator to use",
|
||||||
)
|
)
|
||||||
parser.add_argument('--discrim_weights', type=str, default=None,
|
parser.add_argument("--discrim_weights", type=str, default=None, help="Weights for the generic discriminator")
|
||||||
help='Weights for the generic discriminator')
|
|
||||||
parser.add_argument('--discrim_meta', type=str, default=None,
|
|
||||||
help='Meta information for the generic discriminator')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--class_label",
|
"--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator"
|
||||||
type=int,
|
)
|
||||||
default=-1,
|
parser.add_argument(
|
||||||
help="Class label used for the discriminator",
|
"--class_label", type=int, default=-1, help="Class label used for the discriminator",
|
||||||
)
|
)
|
||||||
parser.add_argument("--length", type=int, default=100)
|
parser.add_argument("--length", type=int, default=100)
|
||||||
parser.add_argument("--stepsize", type=float, default=0.02)
|
parser.add_argument("--stepsize", type=float, default=0.02)
|
||||||
parser.add_argument("--temperature", type=float, default=1.0)
|
parser.add_argument("--temperature", type=float, default=1.0)
|
||||||
parser.add_argument("--top_k", type=int, default=10)
|
parser.add_argument("--top_k", type=int, default=10)
|
||||||
parser.add_argument(
|
parser.add_argument("--sample", action="store_true", help="Generate from end-of-text as prefix")
|
||||||
"--sample", action="store_true",
|
|
||||||
help="Generate from end-of-text as prefix"
|
|
||||||
)
|
|
||||||
parser.add_argument("--num_iterations", type=int, default=3)
|
parser.add_argument("--num_iterations", type=int, default=3)
|
||||||
parser.add_argument("--grad_length", type=int, default=10000)
|
parser.add_argument("--grad_length", type=int, default=10000)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--window_length",
|
"--window_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=0,
|
||||||
help="Length of past which is being optimized; "
|
help="Length of past which is being optimized; " "0 corresponds to infinite window length",
|
||||||
"0 corresponds to infinite window length",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--horizon_length",
|
"--horizon_length", type=int, default=1, help="Length of future to optimize over",
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Length of future to optimize over",
|
|
||||||
)
|
)
|
||||||
parser.add_argument("--decay", action="store_true",
|
parser.add_argument("--decay", action="store_true", help="whether to decay or not")
|
||||||
help="whether to decay or not")
|
|
||||||
parser.add_argument("--gamma", type=float, default=1.5)
|
parser.add_argument("--gamma", type=float, default=1.5)
|
||||||
parser.add_argument("--gm_scale", type=float, default=0.9)
|
parser.add_argument("--gm_scale", type=float, default=0.9)
|
||||||
parser.add_argument("--kl_scale", type=float, default=0.01)
|
parser.add_argument("--kl_scale", type=float, default=0.01)
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
|
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
|
||||||
parser.add_argument("--colorama", action="store_true",
|
parser.add_argument("--colorama", action="store_true", help="colors keywords")
|
||||||
help="colors keywords")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
run_pplm_example(**vars(args))
|
run_pplm_example(**vars(args))
|
||||||
|
|||||||
@@ -42,26 +42,15 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
|
|||||||
max_length_seq = 100
|
max_length_seq = 100
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Discriminator(torch.nn.Module):
|
class Discriminator(torch.nn.Module):
|
||||||
"""Transformer encoder followed by a Classification Head"""
|
"""Transformer encoder followed by a Classification Head"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
|
||||||
self,
|
|
||||||
class_size,
|
|
||||||
pretrained_model="gpt2-medium",
|
|
||||||
cached_mode=False,
|
|
||||||
device='cpu'
|
|
||||||
):
|
|
||||||
super(Discriminator, self).__init__()
|
super(Discriminator, self).__init__()
|
||||||
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
||||||
self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
|
self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
|
||||||
self.embed_size = self.encoder.transformer.config.hidden_size
|
self.embed_size = self.encoder.transformer.config.hidden_size
|
||||||
self.classifier_head = ClassificationHead(
|
self.classifier_head = ClassificationHead(class_size=class_size, embed_size=self.embed_size)
|
||||||
class_size=class_size,
|
|
||||||
embed_size=self.embed_size
|
|
||||||
)
|
|
||||||
self.cached_mode = cached_mode
|
self.cached_mode = cached_mode
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
@@ -74,14 +63,10 @@ class Discriminator(torch.nn.Module):
|
|||||||
self.classifier_head.train()
|
self.classifier_head.train()
|
||||||
|
|
||||||
def avg_representation(self, x):
|
def avg_representation(self, x):
|
||||||
mask = x.ne(0).unsqueeze(2).repeat(
|
mask = x.ne(0).unsqueeze(2).repeat(1, 1, self.embed_size).float().to(self.device).detach()
|
||||||
1, 1, self.embed_size
|
|
||||||
).float().to(self.device).detach()
|
|
||||||
hidden, _ = self.encoder.transformer(x)
|
hidden, _ = self.encoder.transformer(x)
|
||||||
masked_hidden = hidden * mask
|
masked_hidden = hidden * mask
|
||||||
avg_hidden = torch.sum(masked_hidden, dim=1) / (
|
avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(mask, dim=1).detach() + EPSILON)
|
||||||
torch.sum(mask, dim=1).detach() + EPSILON
|
|
||||||
)
|
|
||||||
return avg_hidden
|
return avg_hidden
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@@ -117,10 +102,7 @@ def collate_fn(data):
|
|||||||
def pad_sequences(sequences):
|
def pad_sequences(sequences):
|
||||||
lengths = [len(seq) for seq in sequences]
|
lengths = [len(seq) for seq in sequences]
|
||||||
|
|
||||||
padded_sequences = torch.zeros(
|
padded_sequences = torch.zeros(len(sequences), max(lengths)).long() # padding value = 0
|
||||||
len(sequences),
|
|
||||||
max(lengths)
|
|
||||||
).long() # padding value = 0
|
|
||||||
|
|
||||||
for i, seq in enumerate(sequences):
|
for i, seq in enumerate(sequences):
|
||||||
end = lengths[i]
|
end = lengths[i]
|
||||||
@@ -149,8 +131,7 @@ def cached_collate_fn(data):
|
|||||||
return x_batch, y_batch
|
return x_batch, y_batch
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(data_loader, discriminator, optimizer,
|
def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10, device="cpu"):
|
||||||
epoch=0, log_interval=10, device='cpu'):
|
|
||||||
samples_so_far = 0
|
samples_so_far = 0
|
||||||
discriminator.train_custom()
|
discriminator.train_custom()
|
||||||
for batch_idx, (input_t, target_t) in enumerate(data_loader):
|
for batch_idx, (input_t, target_t) in enumerate(data_loader):
|
||||||
@@ -169,13 +150,15 @@ def train_epoch(data_loader, discriminator, optimizer,
|
|||||||
print(
|
print(
|
||||||
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
|
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
|
||||||
epoch + 1,
|
epoch + 1,
|
||||||
samples_so_far, len(data_loader.dataset),
|
samples_so_far,
|
||||||
100 * samples_so_far / len(data_loader.dataset), loss.item()
|
len(data_loader.dataset),
|
||||||
|
100 * samples_so_far / len(data_loader.dataset),
|
||||||
|
loss.item(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_performance(data_loader, discriminator, device='cpu'):
|
def evaluate_performance(data_loader, discriminator, device="cpu"):
|
||||||
discriminator.eval()
|
discriminator.eval()
|
||||||
test_loss = 0
|
test_loss = 0
|
||||||
correct = 0
|
correct = 0
|
||||||
@@ -194,13 +177,12 @@ def evaluate_performance(data_loader, discriminator, device='cpu'):
|
|||||||
print(
|
print(
|
||||||
"Performance on test set: "
|
"Performance on test set: "
|
||||||
"Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
|
"Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
|
||||||
test_loss, correct, len(data_loader.dataset),
|
test_loss, correct, len(data_loader.dataset), 100.0 * correct / len(data_loader.dataset)
|
||||||
100. * correct / len(data_loader.dataset)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def predict(input_sentence, model, classes, cached=False, device='cpu'):
|
def predict(input_sentence, model, classes, cached=False, device="cpu"):
|
||||||
input_t = model.tokenizer.encode(input_sentence)
|
input_t = model.tokenizer.encode(input_sentence)
|
||||||
input_t = torch.tensor([input_t], dtype=torch.long, device=device)
|
input_t = torch.tensor([input_t], dtype=torch.long, device=device)
|
||||||
if cached:
|
if cached:
|
||||||
@@ -208,17 +190,14 @@ def predict(input_sentence, model, classes, cached=False, device='cpu'):
|
|||||||
|
|
||||||
log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
|
log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
|
||||||
print("Input sentence:", input_sentence)
|
print("Input sentence:", input_sentence)
|
||||||
print("Predictions:", ", ".join(
|
print(
|
||||||
"{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
|
"Predictions:",
|
||||||
zip(classes, log_probs)
|
", ".join("{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in zip(classes, log_probs)),
|
||||||
))
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_cached_data_loader(dataset, batch_size, discriminator,
|
def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False, device="cpu"):
|
||||||
shuffle=False, device='cpu'):
|
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate_fn)
|
||||||
data_loader = torch.utils.data.DataLoader(dataset=dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
collate_fn=collate_fn)
|
|
||||||
|
|
||||||
xs = []
|
xs = []
|
||||||
ys = []
|
ys = []
|
||||||
@@ -231,50 +210,44 @@ def get_cached_data_loader(dataset, batch_size, discriminator,
|
|||||||
ys += y.cpu().numpy().tolist()
|
ys += y.cpu().numpy().tolist()
|
||||||
|
|
||||||
data_loader = torch.utils.data.DataLoader(
|
data_loader = torch.utils.data.DataLoader(
|
||||||
dataset=Dataset(xs, ys),
|
dataset=Dataset(xs, ys), batch_size=batch_size, shuffle=shuffle, collate_fn=cached_collate_fn
|
||||||
batch_size=batch_size,
|
)
|
||||||
shuffle=shuffle,
|
|
||||||
collate_fn=cached_collate_fn)
|
|
||||||
|
|
||||||
return data_loader
|
return data_loader
|
||||||
|
|
||||||
|
|
||||||
def train_discriminator(
|
def train_discriminator(
|
||||||
dataset, dataset_fp=None, pretrained_model="gpt2-medium",
|
dataset,
|
||||||
epochs=10, batch_size=64, log_interval=10,
|
dataset_fp=None,
|
||||||
save_model=False, cached=False, no_cuda=False):
|
pretrained_model="gpt2-medium",
|
||||||
|
epochs=10,
|
||||||
|
batch_size=64,
|
||||||
|
log_interval=10,
|
||||||
|
save_model=False,
|
||||||
|
cached=False,
|
||||||
|
no_cuda=False,
|
||||||
|
):
|
||||||
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||||
|
|
||||||
print("Preprocessing {} dataset...".format(dataset))
|
print("Preprocessing {} dataset...".format(dataset))
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
if dataset == "SST":
|
if dataset == "SST":
|
||||||
idx2class = ["positive", "negative", "very positive", "very negative",
|
idx2class = ["positive", "negative", "very positive", "very negative", "neutral"]
|
||||||
"neutral"]
|
|
||||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||||
|
|
||||||
discriminator = Discriminator(
|
discriminator = Discriminator(
|
||||||
class_size=len(idx2class),
|
class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
|
||||||
pretrained_model=pretrained_model,
|
|
||||||
cached_mode=cached,
|
|
||||||
device=device
|
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
text = torchtext_data.Field()
|
text = torchtext_data.Field()
|
||||||
label = torchtext_data.Field(sequential=False)
|
label = torchtext_data.Field(sequential=False)
|
||||||
train_data, val_data, test_data = datasets.SST.splits(
|
train_data, val_data, test_data = datasets.SST.splits(text, label, fine_grained=True, train_subtrees=True,)
|
||||||
text,
|
|
||||||
label,
|
|
||||||
fine_grained=True,
|
|
||||||
train_subtrees=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
x = []
|
x = []
|
||||||
y = []
|
y = []
|
||||||
for i in trange(len(train_data), ascii=True):
|
for i in trange(len(train_data), ascii=True):
|
||||||
seq = TreebankWordDetokenizer().detokenize(
|
seq = TreebankWordDetokenizer().detokenize(vars(train_data[i])["text"])
|
||||||
vars(train_data[i])["text"]
|
|
||||||
)
|
|
||||||
seq = discriminator.tokenizer.encode(seq)
|
seq = discriminator.tokenizer.encode(seq)
|
||||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||||
x.append(seq)
|
x.append(seq)
|
||||||
@@ -284,9 +257,7 @@ def train_discriminator(
|
|||||||
test_x = []
|
test_x = []
|
||||||
test_y = []
|
test_y = []
|
||||||
for i in trange(len(test_data), ascii=True):
|
for i in trange(len(test_data), ascii=True):
|
||||||
seq = TreebankWordDetokenizer().detokenize(
|
seq = TreebankWordDetokenizer().detokenize(vars(test_data[i])["text"])
|
||||||
vars(test_data[i])["text"]
|
|
||||||
)
|
|
||||||
seq = discriminator.tokenizer.encode(seq)
|
seq = discriminator.tokenizer.encode(seq)
|
||||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||||
test_x.append(seq)
|
test_x.append(seq)
|
||||||
@@ -306,10 +277,7 @@ def train_discriminator(
|
|||||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||||
|
|
||||||
discriminator = Discriminator(
|
discriminator = Discriminator(
|
||||||
class_size=len(idx2class),
|
class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
|
||||||
pretrained_model=pretrained_model,
|
|
||||||
cached_mode=cached,
|
|
||||||
device=device
|
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
||||||
@@ -318,9 +286,7 @@ def train_discriminator(
|
|||||||
try:
|
try:
|
||||||
data.append(eval(line))
|
data.append(eval(line))
|
||||||
except:
|
except:
|
||||||
print("Error evaluating line {}: {}".format(
|
print("Error evaluating line {}: {}".format(i, line))
|
||||||
i, line
|
|
||||||
))
|
|
||||||
continue
|
continue
|
||||||
x = []
|
x = []
|
||||||
y = []
|
y = []
|
||||||
@@ -331,27 +297,20 @@ def train_discriminator(
|
|||||||
seq = discriminator.tokenizer.encode(d["text"])
|
seq = discriminator.tokenizer.encode(d["text"])
|
||||||
|
|
||||||
if len(seq) < max_length_seq:
|
if len(seq) < max_length_seq:
|
||||||
seq = torch.tensor(
|
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||||
[50256] + seq, device=device, dtype=torch.long
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print("Line {} is longer than maximum length {}".format(
|
print("Line {} is longer than maximum length {}".format(i, max_length_seq))
|
||||||
i, max_length_seq
|
|
||||||
))
|
|
||||||
continue
|
continue
|
||||||
x.append(seq)
|
x.append(seq)
|
||||||
y.append(d["label"])
|
y.append(d["label"])
|
||||||
except:
|
except:
|
||||||
print("Error evaluating / tokenizing"
|
print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
|
||||||
" line {}, skipping it".format(i))
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
full_dataset = Dataset(x, y)
|
full_dataset = Dataset(x, y)
|
||||||
train_size = int(0.9 * len(full_dataset))
|
train_size = int(0.9 * len(full_dataset))
|
||||||
test_size = len(full_dataset) - train_size
|
test_size = len(full_dataset) - train_size
|
||||||
train_dataset, test_dataset = torch.utils.data.random_split(
|
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
|
||||||
full_dataset, [train_size, test_size]
|
|
||||||
)
|
|
||||||
|
|
||||||
discriminator_meta = {
|
discriminator_meta = {
|
||||||
"class_size": len(idx2class),
|
"class_size": len(idx2class),
|
||||||
@@ -366,10 +325,7 @@ def train_discriminator(
|
|||||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||||
|
|
||||||
discriminator = Discriminator(
|
discriminator = Discriminator(
|
||||||
class_size=len(idx2class),
|
class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
|
||||||
pretrained_model=pretrained_model,
|
|
||||||
cached_mode=cached,
|
|
||||||
device=device
|
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
x = []
|
x = []
|
||||||
@@ -381,27 +337,20 @@ def train_discriminator(
|
|||||||
seq = discriminator.tokenizer.encode(d["text"])
|
seq = discriminator.tokenizer.encode(d["text"])
|
||||||
|
|
||||||
if len(seq) < max_length_seq:
|
if len(seq) < max_length_seq:
|
||||||
seq = torch.tensor(
|
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||||
[50256] + seq, device=device, dtype=torch.long
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print("Line {} is longer than maximum length {}".format(
|
print("Line {} is longer than maximum length {}".format(i, max_length_seq))
|
||||||
i, max_length_seq
|
|
||||||
))
|
|
||||||
continue
|
continue
|
||||||
x.append(seq)
|
x.append(seq)
|
||||||
y.append(int(np.sum(d["label"]) > 0))
|
y.append(int(np.sum(d["label"]) > 0))
|
||||||
except:
|
except:
|
||||||
print("Error evaluating / tokenizing"
|
print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
|
||||||
" line {}, skipping it".format(i))
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
full_dataset = Dataset(x, y)
|
full_dataset = Dataset(x, y)
|
||||||
train_size = int(0.9 * len(full_dataset))
|
train_size = int(0.9 * len(full_dataset))
|
||||||
test_size = len(full_dataset) - train_size
|
test_size = len(full_dataset) - train_size
|
||||||
train_dataset, test_dataset = torch.utils.data.random_split(
|
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
|
||||||
full_dataset, [train_size, test_size]
|
|
||||||
)
|
|
||||||
|
|
||||||
discriminator_meta = {
|
discriminator_meta = {
|
||||||
"class_size": len(idx2class),
|
"class_size": len(idx2class),
|
||||||
@@ -416,8 +365,7 @@ def train_discriminator(
|
|||||||
# class \t text
|
# class \t text
|
||||||
|
|
||||||
if dataset_fp is None:
|
if dataset_fp is None:
|
||||||
raise ValueError("When generic dataset is selected, "
|
raise ValueError("When generic dataset is selected, " "dataset_fp needs to be specified aswell.")
|
||||||
"dataset_fp needs to be specified aswell.")
|
|
||||||
|
|
||||||
classes = set()
|
classes = set()
|
||||||
with open(dataset_fp) as f:
|
with open(dataset_fp) as f:
|
||||||
@@ -430,10 +378,7 @@ def train_discriminator(
|
|||||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||||
|
|
||||||
discriminator = Discriminator(
|
discriminator = Discriminator(
|
||||||
class_size=len(idx2class),
|
class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
|
||||||
pretrained_model=pretrained_model,
|
|
||||||
cached_mode=cached,
|
|
||||||
device=device
|
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
x = []
|
x = []
|
||||||
@@ -447,18 +392,11 @@ def train_discriminator(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
seq = discriminator.tokenizer.encode(text)
|
seq = discriminator.tokenizer.encode(text)
|
||||||
if (len(seq) < max_length_seq):
|
if len(seq) < max_length_seq:
|
||||||
seq = torch.tensor(
|
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||||
[50256] + seq,
|
|
||||||
device=device,
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(
|
print("Line {} is longer than maximum length {}".format(i, max_length_seq))
|
||||||
"Line {} is longer than maximum length {}".format(
|
|
||||||
i, max_length_seq
|
|
||||||
))
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
x.append(seq)
|
x.append(seq)
|
||||||
@@ -471,10 +409,7 @@ def train_discriminator(
|
|||||||
full_dataset = Dataset(x, y)
|
full_dataset = Dataset(x, y)
|
||||||
train_size = int(0.9 * len(full_dataset))
|
train_size = int(0.9 * len(full_dataset))
|
||||||
test_size = len(full_dataset) - train_size
|
test_size = len(full_dataset) - train_size
|
||||||
train_dataset, test_dataset = torch.utils.data.random_split(
|
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
|
||||||
full_dataset,
|
|
||||||
[train_size, test_size]
|
|
||||||
)
|
|
||||||
|
|
||||||
discriminator_meta = {
|
discriminator_meta = {
|
||||||
"class_size": len(idx2class),
|
"class_size": len(idx2class),
|
||||||
@@ -485,9 +420,7 @@ def train_discriminator(
|
|||||||
}
|
}
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print("Preprocessed {} data points".format(
|
print("Preprocessed {} data points".format(len(train_dataset) + len(test_dataset)))
|
||||||
len(train_dataset) + len(test_dataset))
|
|
||||||
)
|
|
||||||
print("Data preprocessing took: {:.3f}s".format(end - start))
|
print("Data preprocessing took: {:.3f}s".format(end - start))
|
||||||
|
|
||||||
if cached:
|
if cached:
|
||||||
@@ -495,30 +428,21 @@ def train_discriminator(
|
|||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
train_loader = get_cached_data_loader(
|
train_loader = get_cached_data_loader(train_dataset, batch_size, discriminator, shuffle=True, device=device)
|
||||||
train_dataset, batch_size, discriminator,
|
|
||||||
shuffle=True, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
test_loader = get_cached_data_loader(
|
test_loader = get_cached_data_loader(test_dataset, batch_size, discriminator, device=device)
|
||||||
test_dataset, batch_size, discriminator, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print("Building representation cache took: {:.3f}s".format(end - start))
|
print("Building representation cache took: {:.3f}s".format(end - start))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
train_loader = torch.utils.data.DataLoader(
|
||||||
batch_size=batch_size,
|
dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
|
||||||
shuffle=True,
|
)
|
||||||
collate_fn=collate_fn)
|
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, collate_fn=collate_fn)
|
||||||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
collate_fn=collate_fn)
|
|
||||||
|
|
||||||
if save_model:
|
if save_model:
|
||||||
with open("{}_classifier_head_meta.json".format(dataset),
|
with open("{}_classifier_head_meta.json".format(dataset), "w") as meta_file:
|
||||||
"w") as meta_file:
|
|
||||||
json.dump(discriminator_meta, meta_file)
|
json.dump(discriminator_meta, meta_file)
|
||||||
|
|
||||||
optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
|
optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
|
||||||
@@ -533,56 +457,61 @@ def train_discriminator(
|
|||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
epoch=epoch,
|
epoch=epoch,
|
||||||
log_interval=log_interval,
|
log_interval=log_interval,
|
||||||
device=device
|
device=device,
|
||||||
)
|
|
||||||
evaluate_performance(
|
|
||||||
data_loader=test_loader,
|
|
||||||
discriminator=discriminator,
|
|
||||||
device=device
|
|
||||||
)
|
)
|
||||||
|
evaluate_performance(data_loader=test_loader, discriminator=discriminator, device=device)
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print("Epoch took: {:.3f}s".format(end - start))
|
print("Epoch took: {:.3f}s".format(end - start))
|
||||||
|
|
||||||
print("\nExample prediction")
|
print("\nExample prediction")
|
||||||
predict(example_sentence, discriminator, idx2class,
|
predict(example_sentence, discriminator, idx2class, cached=cached, device=device)
|
||||||
cached=cached, device=device)
|
|
||||||
|
|
||||||
if save_model:
|
if save_model:
|
||||||
# torch.save(discriminator.state_dict(),
|
# torch.save(discriminator.state_dict(),
|
||||||
# "{}_discriminator_{}.pt".format(
|
# "{}_discriminator_{}.pt".format(
|
||||||
# args.dataset, epoch + 1
|
# args.dataset, epoch + 1
|
||||||
# ))
|
# ))
|
||||||
torch.save(discriminator.get_classifier().state_dict(),
|
torch.save(
|
||||||
"{}_classifier_head_epoch_{}.pt".format(dataset,
|
discriminator.get_classifier().state_dict(),
|
||||||
epoch + 1))
|
"{}_classifier_head_epoch_{}.pt".format(dataset, epoch + 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Train a discriminator on top of GPT-2 representations")
|
||||||
description="Train a discriminator on top of GPT-2 representations")
|
parser.add_argument(
|
||||||
parser.add_argument("--dataset", type=str, default="SST",
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
default="SST",
|
||||||
choices=("SST", "clickbait", "toxic", "generic"),
|
choices=("SST", "clickbait", "toxic", "generic"),
|
||||||
help="dataset to train the discriminator on."
|
help="dataset to train the discriminator on."
|
||||||
"In case of generic, the dataset is expected"
|
"In case of generic, the dataset is expected"
|
||||||
"to be a TSBV file with structure: class \\t text")
|
"to be a TSBV file with structure: class \\t text",
|
||||||
parser.add_argument("--dataset_fp", type=str, default="",
|
)
|
||||||
help="File path of the dataset to use. "
|
parser.add_argument(
|
||||||
"Needed only in case of generic datadset")
|
"--dataset_fp",
|
||||||
parser.add_argument("--pretrained_model", type=str, default="gpt2-medium",
|
type=str,
|
||||||
help="Pretrained model to use as encoder")
|
default="",
|
||||||
parser.add_argument("--epochs", type=int, default=10, metavar="N",
|
help="File path of the dataset to use. " "Needed only in case of generic datadset",
|
||||||
help="Number of training epochs")
|
)
|
||||||
parser.add_argument("--batch_size", type=int, default=64, metavar="N",
|
parser.add_argument(
|
||||||
help="input batch size for training (default: 64)")
|
"--pretrained_model", type=str, default="gpt2-medium", help="Pretrained model to use as encoder"
|
||||||
parser.add_argument("--log_interval", type=int, default=10, metavar="N",
|
)
|
||||||
help="how many batches to wait before logging training status")
|
parser.add_argument("--epochs", type=int, default=10, metavar="N", help="Number of training epochs")
|
||||||
parser.add_argument("--save_model", action="store_true",
|
parser.add_argument(
|
||||||
help="whether to save the model")
|
"--batch_size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
|
||||||
parser.add_argument("--cached", action="store_true",
|
)
|
||||||
help="whether to cache the input representations")
|
parser.add_argument(
|
||||||
parser.add_argument("--no_cuda", action="store_true",
|
"--log_interval",
|
||||||
help="use to turn off cuda")
|
type=int,
|
||||||
|
default=10,
|
||||||
|
metavar="N",
|
||||||
|
help="how many batches to wait before logging training status",
|
||||||
|
)
|
||||||
|
parser.add_argument("--save_model", action="store_true", help="whether to save the model")
|
||||||
|
parser.add_argument("--cached", action="store_true", help="whether to cache the input representations")
|
||||||
|
parser.add_argument("--no_cuda", action="store_true", help="use to turn off cuda")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
train_discriminator(**(vars(args)))
|
train_discriminator(**(vars(args)))
|
||||||
|
|||||||
@@ -32,10 +32,18 @@ from torch.utils.data import DataLoader, SequentialSampler, TensorDataset, Subse
|
|||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from transformers import (WEIGHTS_NAME,
|
from transformers import (
|
||||||
BertConfig, BertForSequenceClassification, BertTokenizer,
|
WEIGHTS_NAME,
|
||||||
XLMConfig, XLMForSequenceClassification, XLMTokenizer,
|
BertConfig,
|
||||||
XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer)
|
BertForSequenceClassification,
|
||||||
|
BertTokenizer,
|
||||||
|
XLMConfig,
|
||||||
|
XLMForSequenceClassification,
|
||||||
|
XLMTokenizer,
|
||||||
|
XLNetConfig,
|
||||||
|
XLNetForSequenceClassification,
|
||||||
|
XLNetTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
from run_glue import set_seed, load_and_cache_examples, ALL_MODELS, MODEL_CLASSES
|
from run_glue import set_seed, load_and_cache_examples, ALL_MODELS, MODEL_CLASSES
|
||||||
|
|
||||||
@@ -63,7 +71,9 @@ def print_2d_tensor(tensor):
|
|||||||
logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:d}" for x in tensor[row].cpu().data))
|
logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:d}" for x in tensor[row].cpu().data))
|
||||||
|
|
||||||
|
|
||||||
def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None):
|
def compute_heads_importance(
|
||||||
|
args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None
|
||||||
|
):
|
||||||
""" This method shows how to compute:
|
""" This method shows how to compute:
|
||||||
- head attention entropy
|
- head attention entropy
|
||||||
- head importance scores according to http://arxiv.org/abs/1905.10650
|
- head importance scores according to http://arxiv.org/abs/1905.10650
|
||||||
@@ -85,8 +95,14 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
|
|||||||
input_ids, input_mask, segment_ids, label_ids = batch
|
input_ids, input_mask, segment_ids, label_ids = batch
|
||||||
|
|
||||||
# Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
|
# Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
|
||||||
outputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids, head_mask=head_mask)
|
outputs = model(
|
||||||
loss, logits, all_attentions = outputs[0], outputs[1], outputs[-1] # Loss and logits are the first, attention the last
|
input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids, head_mask=head_mask
|
||||||
|
)
|
||||||
|
loss, logits, all_attentions = (
|
||||||
|
outputs[0],
|
||||||
|
outputs[1],
|
||||||
|
outputs[-1],
|
||||||
|
) # Loss and logits are the first, attention the last
|
||||||
loss.backward() # Backpropagate to populate the gradients in the head mask
|
loss.backward() # Backpropagate to populate the gradients in the head mask
|
||||||
|
|
||||||
if compute_entropy:
|
if compute_entropy:
|
||||||
@@ -120,8 +136,8 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
|
|||||||
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
|
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
|
||||||
|
|
||||||
# Print/save matrices
|
# Print/save matrices
|
||||||
np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy.detach().cpu().numpy())
|
np.save(os.path.join(args.output_dir, "attn_entropy.npy"), attn_entropy.detach().cpu().numpy())
|
||||||
np.save(os.path.join(args.output_dir, 'head_importance.npy'), head_importance.detach().cpu().numpy())
|
np.save(os.path.join(args.output_dir, "head_importance.npy"), head_importance.detach().cpu().numpy())
|
||||||
|
|
||||||
logger.info("Attention entropies")
|
logger.info("Attention entropies")
|
||||||
print_2d_tensor(attn_entropy)
|
print_2d_tensor(attn_entropy)
|
||||||
@@ -129,7 +145,9 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
|
|||||||
print_2d_tensor(head_importance)
|
print_2d_tensor(head_importance)
|
||||||
logger.info("Head ranked by importance scores")
|
logger.info("Head ranked by importance scores")
|
||||||
head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
|
head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
|
||||||
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel(), device=args.device)
|
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
|
||||||
|
head_importance.numel(), device=args.device
|
||||||
|
)
|
||||||
head_ranks = head_ranks.view_as(head_importance)
|
head_ranks = head_ranks.view_as(head_importance)
|
||||||
print_2d_tensor(head_ranks)
|
print_2d_tensor(head_ranks)
|
||||||
|
|
||||||
@@ -152,7 +170,7 @@ def mask_heads(args, model, eval_dataloader):
|
|||||||
while current_score >= original_score * args.masking_threshold:
|
while current_score >= original_score * args.masking_threshold:
|
||||||
head_mask = new_head_mask.clone() # save current head mask
|
head_mask = new_head_mask.clone() # save current head mask
|
||||||
# heads from least important to most - keep only not-masked heads
|
# heads from least important to most - keep only not-masked heads
|
||||||
head_importance[head_mask == 0.0] = float('Inf')
|
head_importance[head_mask == 0.0] = float("Inf")
|
||||||
current_heads_to_mask = head_importance.view(-1).sort()[1]
|
current_heads_to_mask = head_importance.view(-1).sort()[1]
|
||||||
|
|
||||||
if len(current_heads_to_mask) <= num_to_mask:
|
if len(current_heads_to_mask) <= num_to_mask:
|
||||||
@@ -167,14 +185,21 @@ def mask_heads(args, model, eval_dataloader):
|
|||||||
print_2d_tensor(new_head_mask)
|
print_2d_tensor(new_head_mask)
|
||||||
|
|
||||||
# Compute metric and head importance again
|
# Compute metric and head importance again
|
||||||
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask)
|
_, head_importance, preds, labels = compute_heads_importance(
|
||||||
|
args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
|
||||||
|
)
|
||||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||||
current_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
current_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||||
logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100)
|
logger.info(
|
||||||
|
"Masking: current score: %f, remaning heads %d (%.1f percents)",
|
||||||
|
current_score,
|
||||||
|
new_head_mask.sum(),
|
||||||
|
new_head_mask.sum() / new_head_mask.numel() * 100,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Final head mask")
|
logger.info("Final head mask")
|
||||||
print_2d_tensor(head_mask)
|
print_2d_tensor(head_mask)
|
||||||
np.save(os.path.join(args.output_dir, 'head_mask.npy'), head_mask.detach().cpu().numpy())
|
np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())
|
||||||
|
|
||||||
return head_mask
|
return head_mask
|
||||||
|
|
||||||
@@ -186,8 +211,9 @@ def prune_heads(args, model, eval_dataloader, head_mask):
|
|||||||
# Try pruning and test time speedup
|
# Try pruning and test time speedup
|
||||||
# Pruning is like masking but we actually remove the masked weights
|
# Pruning is like masking but we actually remove the masked weights
|
||||||
before_time = datetime.now()
|
before_time = datetime.now()
|
||||||
_, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
|
_, _, preds, labels = compute_heads_importance(
|
||||||
compute_entropy=False, compute_importance=False, head_mask=head_mask)
|
args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
|
||||||
|
)
|
||||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||||
score_masking = compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
score_masking = compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||||
original_time = datetime.now() - before_time
|
original_time = datetime.now() - before_time
|
||||||
@@ -199,13 +225,19 @@ def prune_heads(args, model, eval_dataloader, head_mask):
|
|||||||
pruned_num_params = sum(p.numel() for p in model.parameters())
|
pruned_num_params = sum(p.numel() for p in model.parameters())
|
||||||
|
|
||||||
before_time = datetime.now()
|
before_time = datetime.now()
|
||||||
_, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
|
_, _, preds, labels = compute_heads_importance(
|
||||||
compute_entropy=False, compute_importance=False, head_mask=None)
|
args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=None
|
||||||
|
)
|
||||||
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
|
||||||
score_pruning = compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
score_pruning = compute_metrics(args.task_name, preds, labels)[args.metric_name]
|
||||||
new_time = datetime.now() - before_time
|
new_time = datetime.now() - before_time
|
||||||
|
|
||||||
logger.info("Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)", original_num_params, pruned_num_params, pruned_num_params/original_num_params * 100)
|
logger.info(
|
||||||
|
"Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)",
|
||||||
|
original_num_params,
|
||||||
|
pruned_num_params,
|
||||||
|
pruned_num_params / original_num_params * 100,
|
||||||
|
)
|
||||||
logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
|
logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
|
||||||
logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time / new_time * 100)
|
logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time / new_time * 100)
|
||||||
|
|
||||||
@@ -213,59 +245,107 @@ def prune_heads(args, model, eval_dataloader, head_mask):
|
|||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
"--data_dir",
|
||||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
default=None,
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(
|
type=str,
|
||||||
ALL_MODELS))
|
required=True,
|
||||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
)
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="The output directory where the model predictions and checkpoints will be written.")
|
"--model_name_or_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task_name",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
|
)
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument(
|
||||||
help="Pretrained config name or path if not the same as model_name_or_path")
|
"--config_name",
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
default="",
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name_or_path")
|
type=str,
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
help="Pretrained config name or path if not the same as model_name_or_path",
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
)
|
||||||
parser.add_argument("--data_subset", type=int, default=-1,
|
parser.add_argument(
|
||||||
help="If > 0: limit the data to a subset of data_subset instances.")
|
"--tokenizer_name",
|
||||||
parser.add_argument("--overwrite_output_dir", action='store_true',
|
default="",
|
||||||
help="Whether to overwrite data in output directory")
|
type=str,
|
||||||
parser.add_argument('--overwrite_cache', action='store_true',
|
help="Pretrained tokenizer name or path if not the same as model_name_or_path",
|
||||||
help="Overwrite the cached training and evaluation sets")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_dir",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--dont_normalize_importance_by_layer", action='store_true',
|
parser.add_argument(
|
||||||
help="Don't normalize importance score by layers")
|
"--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
|
||||||
parser.add_argument("--dont_normalize_global_importance", action='store_true',
|
)
|
||||||
help="Don't normalize all importance scores between 0 and 1")
|
parser.add_argument(
|
||||||
|
"--dont_normalize_global_importance",
|
||||||
|
action="store_true",
|
||||||
|
help="Don't normalize all importance scores between 0 and 1",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--try_masking", action='store_true',
|
parser.add_argument(
|
||||||
help="Whether to try to mask head until a threshold of accuracy.")
|
"--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy."
|
||||||
parser.add_argument("--masking_threshold", default=0.9, type=float,
|
)
|
||||||
help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).")
|
parser.add_argument(
|
||||||
parser.add_argument("--masking_amount", default=0.1, type=float,
|
"--masking_threshold",
|
||||||
help="Amount to heads to masking at each masking step.")
|
default=0.9,
|
||||||
parser.add_argument("--metric_name", default="acc", type=str,
|
type=float,
|
||||||
help="Metric to use for head masking.")
|
help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step."
|
||||||
|
)
|
||||||
|
parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")
|
||||||
|
|
||||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
parser.add_argument(
|
||||||
|
"--max_seq_length",
|
||||||
|
default=128,
|
||||||
|
type=int,
|
||||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||||
"Sequences longer than this will be truncated, sequences shorter padded.")
|
"Sequences longer than this will be truncated, sequences shorter padded.",
|
||||||
|
)
|
||||||
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
|
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
|
||||||
|
|
||||||
parser.add_argument("--seed", type=int, default=42)
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||||
parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
|
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
@@ -278,7 +358,7 @@ def main():
|
|||||||
torch.cuda.set_device(args.local_rank)
|
torch.cuda.set_device(args.local_rank)
|
||||||
args.device = torch.device("cuda", args.local_rank)
|
args.device = torch.device("cuda", args.local_rank)
|
||||||
args.n_gpu = 1
|
args.n_gpu = 1
|
||||||
torch.distributed.init_process_group(backend='nccl') # Initializes the distributed backend
|
torch.distributed.init_process_group(backend="nccl") # Initializes the distributed backend
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||||
@@ -306,17 +386,23 @@ def main():
|
|||||||
args.model_type = key # take the first match in model types
|
args.model_type = key # take the first match in model types
|
||||||
break
|
break
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
config = config_class.from_pretrained(
|
||||||
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
num_labels=num_labels,
|
num_labels=num_labels,
|
||||||
finetuning_task=args.task_name,
|
finetuning_task=args.task_name,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
)
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
model = model_class.from_pretrained(args.model_name_or_path,
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
model = model_class.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
@@ -324,14 +410,14 @@ def main():
|
|||||||
# Distributed and parallel training
|
# Distributed and parallel training
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
output_device=args.local_rank,
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
find_unused_parameters=True)
|
)
|
||||||
elif args.n_gpu > 1:
|
elif args.n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
# Print/save training arguments
|
# Print/save training arguments
|
||||||
torch.save(args, os.path.join(args.output_dir, 'run_args.bin'))
|
torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
|
||||||
logger.info("Training/evaluation parameters %s", args)
|
logger.info("Training/evaluation parameters %s", args)
|
||||||
|
|
||||||
# Prepare dataset for the GLUE task
|
# Prepare dataset for the GLUE task
|
||||||
@@ -341,11 +427,9 @@ def main():
|
|||||||
eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
|
eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
|
||||||
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
|
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
|
||||||
|
|
||||||
|
|
||||||
# Compute head entropy and importance score
|
# Compute head entropy and importance score
|
||||||
compute_heads_importance(args, model, eval_dataloader)
|
compute_heads_importance(args, model, eval_dataloader)
|
||||||
|
|
||||||
|
|
||||||
# Try head masking (set heads to zero until the score goes under a threshole)
|
# Try head masking (set heads to zero until the score goes under a threshole)
|
||||||
# and head pruning (remove masked heads and see the effect on the network)
|
# and head pruning (remove masked heads and see the effect on the network)
|
||||||
if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
|
if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
|
||||||
@@ -353,5 +437,5 @@ def main():
|
|||||||
prune_heads(args, model, eval_dataloader, head_mask)
|
prune_heads(args, model, eval_dataloader, head_mask)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -33,9 +33,7 @@ from transformers import XLMWithLMHeadModel, XLMTokenizer
|
|||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,
|
||||||
datefmt="%m/%d/%Y %H:%M:%S",
|
|
||||||
level=logging.INFO,
|
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -71,6 +69,7 @@ def set_seed(args):
|
|||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Functions to prepare models' input
|
# Functions to prepare models' input
|
||||||
#
|
#
|
||||||
@@ -78,15 +77,11 @@ def set_seed(args):
|
|||||||
|
|
||||||
def prepare_ctrl_input(args, _, tokenizer, prompt_text):
|
def prepare_ctrl_input(args, _, tokenizer, prompt_text):
|
||||||
if args.temperature > 0.7:
|
if args.temperature > 0.7:
|
||||||
logger.info(
|
logger.info("CTRL typically works better with lower temperatures (and lower top_k).")
|
||||||
"CTRL typically works better with lower temperatures (and lower top_k)."
|
|
||||||
)
|
|
||||||
|
|
||||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
|
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
|
||||||
if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
|
if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
|
||||||
logger.info(
|
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
|
||||||
"WARNING! You are not starting your generation from a control code so you won't get good results"
|
|
||||||
)
|
|
||||||
return prompt_text
|
return prompt_text
|
||||||
|
|
||||||
|
|
||||||
@@ -102,11 +97,7 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
|||||||
else:
|
else:
|
||||||
language = None
|
language = None
|
||||||
while language not in available_languages:
|
while language not in available_languages:
|
||||||
language = input(
|
language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
|
||||||
"Using XLM. Select language in "
|
|
||||||
+ str(list(available_languages))
|
|
||||||
+ " >>> "
|
|
||||||
)
|
|
||||||
# kwargs["language"] = tokenizer.lang2id[language]
|
# kwargs["language"] = tokenizer.lang2id[language]
|
||||||
|
|
||||||
# TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
|
# TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
|
||||||
@@ -148,17 +139,34 @@ def adjust_length_to_model(length, max_sequence_length):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
"--model_type",
|
||||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
default=None,
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--prompt", type=str, default="")
|
parser.add_argument("--prompt", type=str, default="")
|
||||||
parser.add_argument("--length", type=int, default=20)
|
parser.add_argument("--length", type=int, default=20)
|
||||||
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
|
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
|
||||||
|
|
||||||
parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 1.0 has no effect, lower tend toward greedy sampling")
|
parser.add_argument(
|
||||||
parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2")
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
|
||||||
|
)
|
||||||
parser.add_argument("--k", type=int, default=0)
|
parser.add_argument("--k", type=int, default=0)
|
||||||
parser.add_argument("--p", type=float, default=0.9)
|
parser.add_argument("--p", type=float, default=0.9)
|
||||||
|
|
||||||
@@ -169,9 +177,7 @@ def main():
|
|||||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args.device = torch.device(
|
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||||
"cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
|
|
||||||
)
|
|
||||||
args.n_gpu = torch.cuda.device_count()
|
args.n_gpu = torch.cuda.device_count()
|
||||||
|
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
@@ -181,17 +187,13 @@ def main():
|
|||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise KeyError(
|
raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")
|
||||||
"the model {} you specified is not supported. You are welcome to add it and open a PR :)"
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||||
model = model_class.from_pretrained(args.model_name_or_path)
|
model = model_class.from_pretrained(args.model_name_or_path)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
args.length = adjust_length_to_model(
|
args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
|
||||||
args.length, max_sequence_length=model.config.max_position_embeddings
|
|
||||||
)
|
|
||||||
logger.info(args)
|
logger.info(args)
|
||||||
|
|
||||||
prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
|
prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
|
||||||
@@ -201,7 +203,7 @@ def main():
|
|||||||
if requires_preprocessing:
|
if requires_preprocessing:
|
||||||
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
||||||
prompt_text = prepare_input(args, model, tokenizer, prompt_text)
|
prompt_text = prepare_input(args, model, tokenizer, prompt_text)
|
||||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')
|
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
||||||
|
|
||||||
output_sequences = model.generate(
|
output_sequences = model.generate(
|
||||||
input_ids=encoded_prompt,
|
input_ids=encoded_prompt,
|
||||||
|
|||||||
@@ -26,8 +26,7 @@ import json
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
TensorDataset)
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -37,13 +36,18 @@ except:
|
|||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
from transformers import (
|
||||||
BertForSequenceClassification, BertTokenizer,
|
WEIGHTS_NAME,
|
||||||
|
BertConfig,
|
||||||
|
BertForSequenceClassification,
|
||||||
|
BertTokenizer,
|
||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
RobertaForSequenceClassification,
|
RobertaForSequenceClassification,
|
||||||
RobertaTokenizer,
|
RobertaTokenizer,
|
||||||
XLMConfig, XLMForSequenceClassification,
|
XLMConfig,
|
||||||
XLMTokenizer, XLNetConfig,
|
XLMForSequenceClassification,
|
||||||
|
XLMTokenizer,
|
||||||
|
XLNetConfig,
|
||||||
XLNetForSequenceClassification,
|
XLNetForSequenceClassification,
|
||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
DistilBertConfig,
|
DistilBertConfig,
|
||||||
@@ -66,17 +70,22 @@ from transformers import glue_convert_examples_to_features as convert_examples_t
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig,
|
ALL_MODELS = sum(
|
||||||
RobertaConfig, DistilBertConfig)), ())
|
(
|
||||||
|
tuple(conf.pretrained_config_archive_map.keys())
|
||||||
|
for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
|
||||||
|
),
|
||||||
|
(),
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
|
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||||
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
"xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
||||||
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
"xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||||
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
"roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
||||||
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
|
"distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
|
||||||
'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
|
"albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
|
||||||
'xlmroberta': (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
|
"xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -104,20 +113,27 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
|
|
||||||
# Prepare optimizer and schedule (linear warmup and decay)
|
# Prepare optimizer and schedule (linear warmup and decay)
|
||||||
no_decay = ['bias', 'LayerNorm.weight']
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
{
|
||||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
},
|
||||||
|
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||||
]
|
]
|
||||||
|
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||||
|
)
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
|
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||||
|
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||||
|
):
|
||||||
# Load in optimizer and scheduler states
|
# Load in optimizer and scheduler states
|
||||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
|
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
|
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
@@ -132,17 +148,21 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
output_device=args.local_rank,
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
find_unused_parameters=True)
|
)
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_dataset))
|
logger.info(" Num examples = %d", len(train_dataset))
|
||||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
logger.info(
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
|
args.train_batch_size
|
||||||
|
* args.gradient_accumulation_steps
|
||||||
|
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||||
|
)
|
||||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
@@ -152,7 +172,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
if os.path.exists(args.model_name_or_path):
|
if os.path.exists(args.model_name_or_path):
|
||||||
# set global_step to gobal_step of last saved checkpoint from model path
|
# set global_step to gobal_step of last saved checkpoint from model path
|
||||||
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
|
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
|
|
||||||
@@ -163,7 +183,9 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
tr_loss, logging_loss = 0.0, 0.0
|
tr_loss, logging_loss = 0.0, 0.0
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
train_iterator = trange(
|
||||||
|
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||||
|
)
|
||||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||||
for _ in train_iterator:
|
for _ in train_iterator:
|
||||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||||
@@ -176,11 +198,11 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||||
'attention_mask': batch[1],
|
if args.model_type != "distilbert":
|
||||||
'labels': batch[3]}
|
inputs["token_type_ids"] = (
|
||||||
if args.model_type != 'distilbert':
|
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||||
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids
|
) # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||||
|
|
||||||
@@ -209,36 +231,40 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
logs = {}
|
logs = {}
|
||||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
if (
|
||||||
|
args.local_rank == -1 and args.evaluate_during_training
|
||||||
|
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||||
results = evaluate(args, model, tokenizer)
|
results = evaluate(args, model, tokenizer)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
eval_key = 'eval_{}'.format(key)
|
eval_key = "eval_{}".format(key)
|
||||||
logs[eval_key] = value
|
logs[eval_key] = value
|
||||||
|
|
||||||
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
||||||
learning_rate_scalar = scheduler.get_lr()[0]
|
learning_rate_scalar = scheduler.get_lr()[0]
|
||||||
logs['learning_rate'] = learning_rate_scalar
|
logs["learning_rate"] = learning_rate_scalar
|
||||||
logs['loss'] = loss_scalar
|
logs["loss"] = loss_scalar
|
||||||
logging_loss = tr_loss
|
logging_loss = tr_loss
|
||||||
|
|
||||||
for key, value in logs.items():
|
for key, value in logs.items():
|
||||||
tb_writer.add_scalar(key, value, global_step)
|
tb_writer.add_scalar(key, value, global_step)
|
||||||
print(json.dumps({**logs, **{'step': global_step}}))
|
print(json.dumps({**logs, **{"step": global_step}}))
|
||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(output_dir)
|
model_to_save.save_pretrained(output_dir)
|
||||||
tokenizer.save_pretrained(output_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
|
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
|
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
@@ -257,7 +283,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
def evaluate(args, model, tokenizer, prefix=""):
|
def evaluate(args, model, tokenizer, prefix=""):
|
||||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||||
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
||||||
eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
|
eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||||
@@ -288,11 +314,11 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||||
'attention_mask': batch[1],
|
if args.model_type != "distilbert":
|
||||||
'labels': batch[3]}
|
inputs["token_type_ids"] = (
|
||||||
if args.model_type != 'distilbert':
|
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||||
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids
|
) # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
tmp_eval_loss, logits = outputs[:2]
|
tmp_eval_loss, logits = outputs[:2]
|
||||||
|
|
||||||
@@ -300,10 +326,10 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
nb_eval_steps += 1
|
nb_eval_steps += 1
|
||||||
if preds is None:
|
if preds is None:
|
||||||
preds = logits.detach().cpu().numpy()
|
preds = logits.detach().cpu().numpy()
|
||||||
out_label_ids = inputs['labels'].detach().cpu().numpy()
|
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||||
else:
|
else:
|
||||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||||
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
|
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||||
|
|
||||||
eval_loss = eval_loss / nb_eval_steps
|
eval_loss = eval_loss / nb_eval_steps
|
||||||
if args.output_mode == "classification":
|
if args.output_mode == "classification":
|
||||||
@@ -330,29 +356,36 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
|||||||
processor = processors[task]()
|
processor = processors[task]()
|
||||||
output_mode = output_modes[task]
|
output_mode = output_modes[task]
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
|
cached_features_file = os.path.join(
|
||||||
'dev' if evaluate else 'train',
|
args.data_dir,
|
||||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
"cached_{}_{}_{}_{}".format(
|
||||||
|
"dev" if evaluate else "train",
|
||||||
|
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||||
str(args.max_seq_length),
|
str(args.max_seq_length),
|
||||||
str(task)))
|
str(task),
|
||||||
|
),
|
||||||
|
)
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
features = torch.load(cached_features_file)
|
features = torch.load(cached_features_file)
|
||||||
else:
|
else:
|
||||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||||
label_list = processor.get_labels()
|
label_list = processor.get_labels()
|
||||||
if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta', 'xlmroberta']:
|
if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
|
||||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||||
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
examples = (
|
||||||
features = convert_examples_to_features(examples,
|
processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||||
|
)
|
||||||
|
features = convert_examples_to_features(
|
||||||
|
examples,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
label_list=label_list,
|
label_list=label_list,
|
||||||
max_length=args.max_seq_length,
|
max_length=args.max_seq_length,
|
||||||
output_mode=output_mode,
|
output_mode=output_mode,
|
||||||
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
pad_on_left=bool(args.model_type in ["xlnet"]), # pad on the left for xlnet
|
||||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||||
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
|
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
|
||||||
)
|
)
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
@@ -378,90 +411,149 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
"--data_dir",
|
||||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
default=None,
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
type=str,
|
||||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
required=True,
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
)
|
||||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
parser.add_argument(
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
"--model_type",
|
||||||
help="The output directory where the model predictions and checkpoints will be written.")
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task_name",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
|
)
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument(
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
)
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
parser.add_argument(
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
"--tokenizer_name",
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
default="",
|
||||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
type=str,
|
||||||
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_dir",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_seq_length",
|
||||||
|
default=128,
|
||||||
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded.")
|
"than this will be truncated, sequences shorter will be padded.",
|
||||||
parser.add_argument("--do_train", action='store_true',
|
)
|
||||||
help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action='store_true',
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
help="Whether to run eval on the dev set.")
|
parser.add_argument(
|
||||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||||
help="Rul evaluation during training at each logging step.")
|
)
|
||||||
parser.add_argument("--do_lower_case", action='store_true',
|
parser.add_argument(
|
||||||
help="Set this flag if you are using an uncased model.")
|
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||||
help="Batch size per GPU/CPU for training.")
|
parser.add_argument(
|
||||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||||
help="Batch size per GPU/CPU for evaluation.")
|
)
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
parser.add_argument(
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
"--gradient_accumulation_steps",
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
type=int,
|
||||||
help="The initial learning rate for Adam.")
|
default=1,
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
help="Weight decay if we apply some.")
|
)
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
help="Epsilon for Adam optimizer.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
help="Max gradient norm.")
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument(
|
||||||
help="Total number of training epochs to perform.")
|
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
)
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
parser.add_argument(
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
"--max_steps",
|
||||||
help="Linear warmup over warmup_steps.")
|
default=-1,
|
||||||
|
type=int,
|
||||||
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
|
||||||
parser.add_argument('--logging_steps', type=int, default=50,
|
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||||
help="Log every X updates steps.")
|
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||||
parser.add_argument('--save_steps', type=int, default=50,
|
parser.add_argument(
|
||||||
help="Save checkpoint every X updates steps.")
|
"--eval_all_checkpoints",
|
||||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
action="store_true",
|
||||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
)
|
||||||
help="Avoid using CUDA when available")
|
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
parser.add_argument(
|
||||||
help="Overwrite the content of the output directory")
|
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||||
parser.add_argument('--overwrite_cache', action='store_true',
|
)
|
||||||
help="Overwrite the cached training and evaluation sets")
|
parser.add_argument(
|
||||||
parser.add_argument('--seed', type=int, default=42,
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
help="random seed for initialization")
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
|
||||||
parser.add_argument('--fp16', action='store_true',
|
parser.add_argument(
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
"--fp16",
|
||||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
action="store_true",
|
||||||
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16_opt_level",
|
||||||
|
type=str,
|
||||||
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html")
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
parser.add_argument("--local_rank", type=int, default=-1,
|
)
|
||||||
help="For distributed training: local_rank")
|
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
if (
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
os.path.exists(args.output_dir)
|
||||||
|
and os.listdir(args.output_dir)
|
||||||
|
and args.do_train
|
||||||
|
and not args.overwrite_output_dir
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||||
|
args.output_dir
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Setup distant debugging if needed
|
# Setup distant debugging if needed
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
@@ -473,16 +565,24 @@ def main():
|
|||||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||||
torch.cuda.set_device(args.local_rank)
|
torch.cuda.set_device(args.local_rank)
|
||||||
device = torch.device("cuda", args.local_rank)
|
device = torch.device("cuda", args.local_rank)
|
||||||
torch.distributed.init_process_group(backend='nccl')
|
torch.distributed.init_process_group(backend="nccl")
|
||||||
args.n_gpu = 1
|
args.n_gpu = 1
|
||||||
args.device = device
|
args.device = device
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||||
|
args.local_rank,
|
||||||
|
device,
|
||||||
|
args.n_gpu,
|
||||||
|
bool(args.local_rank != -1),
|
||||||
|
args.fp16,
|
||||||
|
)
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
@@ -502,17 +602,23 @@ def main():
|
|||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
config = config_class.from_pretrained(
|
||||||
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
num_labels=num_labels,
|
num_labels=num_labels,
|
||||||
finetuning_task=args.task_name,
|
finetuning_task=args.task_name,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
)
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
do_lower_case=args.do_lower_case,
|
do_lower_case=args.do_lower_case,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
model = model_class.from_pretrained(args.model_name_or_path,
|
)
|
||||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
model = model_class.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
@@ -521,14 +627,12 @@ def main():
|
|||||||
|
|
||||||
logger.info("Training/evaluation parameters %s", args)
|
logger.info("Training/evaluation parameters %s", args)
|
||||||
|
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
|
|
||||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
@@ -538,36 +642,39 @@ def main():
|
|||||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(args.output_dir)
|
model_to_save.save_pretrained(args.output_dir)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = model_class.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
checkpoints = list(
|
||||||
|
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||||
|
)
|
||||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||||
|
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = model_class.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -42,37 +42,55 @@ except:
|
|||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
|
from transformers import (
|
||||||
BertConfig, BertForMaskedLM, BertTokenizer,
|
WEIGHTS_NAME,
|
||||||
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
AdamW,
|
||||||
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
get_linear_schedule_with_warmup,
|
||||||
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
|
BertConfig,
|
||||||
DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer,
|
BertForMaskedLM,
|
||||||
CamembertConfig, CamembertForMaskedLM, CamembertTokenizer)
|
BertTokenizer,
|
||||||
|
GPT2Config,
|
||||||
|
GPT2LMHeadModel,
|
||||||
|
GPT2Tokenizer,
|
||||||
|
OpenAIGPTConfig,
|
||||||
|
OpenAIGPTLMHeadModel,
|
||||||
|
OpenAIGPTTokenizer,
|
||||||
|
RobertaConfig,
|
||||||
|
RobertaForMaskedLM,
|
||||||
|
RobertaTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertForMaskedLM,
|
||||||
|
DistilBertTokenizer,
|
||||||
|
CamembertConfig,
|
||||||
|
CamembertForMaskedLM,
|
||||||
|
CamembertTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||||
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
"openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
||||||
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
|
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||||
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||||
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||||
'camembert': (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer)
|
"camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
class TextDataset(Dataset):
|
||||||
def __init__(self, tokenizer, args, file_path='train', block_size=512):
|
def __init__(self, tokenizer, args, file_path="train", block_size=512):
|
||||||
assert os.path.isfile(file_path)
|
assert os.path.isfile(file_path)
|
||||||
directory, filename = os.path.split(file_path)
|
directory, filename = os.path.split(file_path)
|
||||||
cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_' + filename)
|
cached_features_file = os.path.join(
|
||||||
|
directory, args.model_name_or_path + "_cached_lm_" + str(block_size) + "_" + filename
|
||||||
|
)
|
||||||
|
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
with open(cached_features_file, 'rb') as handle:
|
with open(cached_features_file, "rb") as handle:
|
||||||
self.examples = pickle.load(handle)
|
self.examples = pickle.load(handle)
|
||||||
else:
|
else:
|
||||||
logger.info("Creating features from dataset file at %s", directory)
|
logger.info("Creating features from dataset file at %s", directory)
|
||||||
@@ -90,7 +108,7 @@ class TextDataset(Dataset):
|
|||||||
# can change this behavior by adding (model specific) padding.
|
# can change this behavior by adding (model specific) padding.
|
||||||
|
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
with open(cached_features_file, 'wb') as handle:
|
with open(cached_features_file, "wb") as handle:
|
||||||
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@@ -101,7 +119,12 @@ class TextDataset(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
||||||
dataset = TextDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
|
dataset = TextDataset(
|
||||||
|
tokenizer,
|
||||||
|
args,
|
||||||
|
file_path=args.eval_data_file if evaluate else args.train_data_file,
|
||||||
|
block_size=args.block_size,
|
||||||
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
@@ -120,7 +143,7 @@ def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Check if we should delete older checkpoint(s)
|
# Check if we should delete older checkpoint(s)
|
||||||
glob_checkpoints = glob.glob(os.path.join(args.output_dir, '{}-*'.format(checkpoint_prefix)))
|
glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
|
||||||
if len(glob_checkpoints) <= args.save_total_limit:
|
if len(glob_checkpoints) <= args.save_total_limit:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -129,7 +152,7 @@ def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
|
|||||||
if use_mtime:
|
if use_mtime:
|
||||||
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
|
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
|
||||||
else:
|
else:
|
||||||
regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path)
|
regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
|
||||||
if regex_match and regex_match.groups():
|
if regex_match and regex_match.groups():
|
||||||
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
|
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
|
||||||
|
|
||||||
@@ -147,7 +170,9 @@ def mask_tokens(inputs, tokenizer, args):
|
|||||||
labels = inputs.clone()
|
labels = inputs.clone()
|
||||||
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
||||||
probability_matrix = torch.full(labels.shape, args.mlm_probability)
|
probability_matrix = torch.full(labels.shape, args.mlm_probability)
|
||||||
special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
|
special_tokens_mask = [
|
||||||
|
tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
||||||
|
]
|
||||||
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
|
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
|
||||||
masked_indices = torch.bernoulli(probability_matrix).bool()
|
masked_indices = torch.bernoulli(probability_matrix).bool()
|
||||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||||
@@ -181,19 +206,26 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
|
|
||||||
# Prepare optimizer and schedule (linear warmup and decay)
|
# Prepare optimizer and schedule (linear warmup and decay)
|
||||||
no_decay = ['bias', 'LayerNorm.weight']
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
{
|
||||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
},
|
||||||
|
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||||
]
|
]
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||||
|
)
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
|
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||||
|
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||||
|
):
|
||||||
# Load in optimizer and scheduler states
|
# Load in optimizer and scheduler states
|
||||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
|
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
|
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
@@ -208,17 +240,21 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
output_device=args.local_rank,
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
find_unused_parameters=True)
|
)
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_dataset))
|
logger.info(" Num examples = %d", len(train_dataset))
|
||||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
logger.info(
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
|
args.train_batch_size
|
||||||
|
* args.gradient_accumulation_steps
|
||||||
|
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||||
|
)
|
||||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
@@ -228,7 +264,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
if os.path.exists(args.model_name_or_path):
|
if os.path.exists(args.model_name_or_path):
|
||||||
# set global_step to gobal_step of last saved checkpoint from model path
|
# set global_step to gobal_step of last saved checkpoint from model path
|
||||||
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
|
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
|
|
||||||
@@ -239,11 +275,13 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
tr_loss, logging_loss = 0.0, 0.0
|
tr_loss, logging_loss = 0.0, 0.0
|
||||||
|
|
||||||
model_to_resize = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_resize = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training
|
||||||
model_to_resize.resize_token_embeddings(len(tokenizer))
|
model_to_resize.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
train_iterator = trange(
|
||||||
|
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||||
|
)
|
||||||
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
||||||
for _ in train_iterator:
|
for _ in train_iterator:
|
||||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||||
@@ -285,31 +323,35 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
# Log metrics
|
# Log metrics
|
||||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
if (
|
||||||
|
args.local_rank == -1 and args.evaluate_during_training
|
||||||
|
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||||
results = evaluate(args, model, tokenizer)
|
results = evaluate(args, model, tokenizer)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||||
logging_loss = tr_loss
|
logging_loss = tr_loss
|
||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||||
checkpoint_prefix = 'checkpoint'
|
checkpoint_prefix = "checkpoint"
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
|
output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(output_dir)
|
model_to_save.save_pretrained(output_dir)
|
||||||
tokenizer.save_pretrained(output_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
_rotate_checkpoints(args, checkpoint_prefix)
|
_rotate_checkpoints(args, checkpoint_prefix)
|
||||||
|
|
||||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
|
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
|
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
@@ -365,9 +407,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
eval_loss = eval_loss / nb_eval_steps
|
eval_loss = eval_loss / nb_eval_steps
|
||||||
perplexity = torch.exp(torch.tensor(eval_loss))
|
perplexity = torch.exp(torch.tensor(eval_loss))
|
||||||
|
|
||||||
result = {
|
result = {"perplexity": perplexity}
|
||||||
"perplexity": perplexity
|
|
||||||
}
|
|
||||||
|
|
||||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||||
with open(output_eval_file, "w") as writer:
|
with open(output_eval_file, "w") as writer:
|
||||||
@@ -383,107 +423,167 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="The input training data file (a text file).")
|
"--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
)
|
||||||
help="The output directory where the model predictions and checkpoints will be written.")
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
|
)
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--eval_data_file", default=None, type=str,
|
parser.add_argument(
|
||||||
help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
|
"--eval_data_file",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--model_type", default="bert", type=str,
|
parser.add_argument("--model_type", default="bert", type=str, help="The model architecture to be fine-tuned.")
|
||||||
help="The model architecture to be fine-tuned.")
|
parser.add_argument(
|
||||||
parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str,
|
"--model_name_or_path",
|
||||||
help="The model checkpoint for weights initialization.")
|
default="bert-base-cased",
|
||||||
|
type=str,
|
||||||
|
help="The model checkpoint for weights initialization.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--mlm", action='store_true',
|
parser.add_argument(
|
||||||
help="Train with masked-language modeling loss instead of language modeling.")
|
"--mlm", action="store_true", help="Train with masked-language modeling loss instead of language modeling."
|
||||||
parser.add_argument("--mlm_probability", type=float, default=0.15,
|
)
|
||||||
help="Ratio of tokens to mask for masked language modeling loss")
|
parser.add_argument(
|
||||||
|
"--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument(
|
||||||
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
"--config_name",
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
default="",
|
||||||
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
type=str,
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
help="Optional pretrained config name or path if not the same as model_name_or_path",
|
||||||
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
|
)
|
||||||
parser.add_argument("--block_size", default=-1, type=int,
|
parser.add_argument(
|
||||||
|
"--tokenizer_name",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_dir",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--block_size",
|
||||||
|
default=-1,
|
||||||
|
type=int,
|
||||||
help="Optional input sequence length after tokenization."
|
help="Optional input sequence length after tokenization."
|
||||||
"The training dataset will be truncated in block of this size for training."
|
"The training dataset will be truncated in block of this size for training."
|
||||||
"Default to the model max input length for single sentence inputs (take into account special tokens).")
|
"Default to the model max input length for single sentence inputs (take into account special tokens).",
|
||||||
parser.add_argument("--do_train", action='store_true',
|
)
|
||||||
help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action='store_true',
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
help="Whether to run eval on the dev set.")
|
parser.add_argument(
|
||||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
|
||||||
help="Run evaluation during training at each logging step.")
|
)
|
||||||
parser.add_argument("--do_lower_case", action='store_true',
|
parser.add_argument(
|
||||||
help="Set this flag if you are using an uncased model.")
|
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
|
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
|
||||||
help="Batch size per GPU/CPU for training.")
|
parser.add_argument(
|
||||||
parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int,
|
"--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||||
help="Batch size per GPU/CPU for evaluation.")
|
)
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
parser.add_argument(
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
"--gradient_accumulation_steps",
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
type=int,
|
||||||
help="The initial learning rate for Adam.")
|
default=1,
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
help="Weight decay if we apply some.")
|
)
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
help="Epsilon for Adam optimizer.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
help="Max gradient norm.")
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||||
parser.add_argument("--num_train_epochs", default=1.0, type=float,
|
parser.add_argument(
|
||||||
help="Total number of training epochs to perform.")
|
"--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
)
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
parser.add_argument(
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
"--max_steps",
|
||||||
help="Linear warmup over warmup_steps.")
|
default=-1,
|
||||||
|
type=int,
|
||||||
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
|
||||||
parser.add_argument('--logging_steps', type=int, default=50,
|
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||||
help="Log every X updates steps.")
|
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||||
parser.add_argument('--save_steps', type=int, default=50,
|
parser.add_argument(
|
||||||
help="Save checkpoint every X updates steps.")
|
"--save_total_limit",
|
||||||
parser.add_argument('--save_total_limit', type=int, default=None,
|
type=int,
|
||||||
help='Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default')
|
default=None,
|
||||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
|
||||||
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
|
)
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
parser.add_argument(
|
||||||
help="Avoid using CUDA when available")
|
"--eval_all_checkpoints",
|
||||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
action="store_true",
|
||||||
help="Overwrite the content of the output directory")
|
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number",
|
||||||
parser.add_argument('--overwrite_cache', action='store_true',
|
)
|
||||||
help="Overwrite the cached training and evaluation sets")
|
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||||
parser.add_argument('--seed', type=int, default=42,
|
parser.add_argument(
|
||||||
help="random seed for initialization")
|
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
|
||||||
parser.add_argument('--fp16', action='store_true',
|
parser.add_argument(
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
"--fp16",
|
||||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
action="store_true",
|
||||||
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16_opt_level",
|
||||||
|
type=str,
|
||||||
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html")
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
parser.add_argument("--local_rank", type=int, default=-1,
|
)
|
||||||
help="For distributed training: local_rank")
|
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
|
if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
|
||||||
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
raise ValueError(
|
||||||
"flag (masked language modeling).")
|
"BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
||||||
|
"flag (masked language modeling)."
|
||||||
|
)
|
||||||
if args.eval_data_file is None and args.do_eval:
|
if args.eval_data_file is None and args.do_eval:
|
||||||
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
raise ValueError(
|
||||||
"or remove the --do_eval argument.")
|
"Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
||||||
|
"or remove the --do_eval argument."
|
||||||
|
)
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
if (
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
os.path.exists(args.output_dir)
|
||||||
|
and os.listdir(args.output_dir)
|
||||||
|
and args.do_train
|
||||||
|
and not args.overwrite_output_dir
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||||
|
args.output_dir
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Setup distant debugging if needed
|
# Setup distant debugging if needed
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
@@ -495,16 +595,24 @@ def main():
|
|||||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||||
torch.cuda.set_device(args.local_rank)
|
torch.cuda.set_device(args.local_rank)
|
||||||
device = torch.device("cuda", args.local_rank)
|
device = torch.device("cuda", args.local_rank)
|
||||||
torch.distributed.init_process_group(backend='nccl')
|
torch.distributed.init_process_group(backend="nccl")
|
||||||
args.n_gpu = 1
|
args.n_gpu = 1
|
||||||
args.device = device
|
args.device = device
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||||
|
args.local_rank,
|
||||||
|
device,
|
||||||
|
args.n_gpu,
|
||||||
|
bool(args.local_rank != -1),
|
||||||
|
args.fp16,
|
||||||
|
)
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
@@ -514,18 +622,26 @@ def main():
|
|||||||
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
||||||
|
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
config = config_class.from_pretrained(
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
do_lower_case=args.do_lower_case,
|
do_lower_case=args.do_lower_case,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
if args.block_size <= 0:
|
if args.block_size <= 0:
|
||||||
args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model
|
args.block_size = (
|
||||||
|
tokenizer.max_len_single_sentence
|
||||||
|
) # Our input block size will be the max possible for the model
|
||||||
args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
|
args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
|
||||||
model = model_class.from_pretrained(args.model_name_or_path,
|
model = model_class.from_pretrained(
|
||||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
args.model_name_or_path,
|
||||||
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
@@ -546,7 +662,6 @@ def main():
|
|||||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
|
|
||||||
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
||||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
@@ -556,35 +671,38 @@ def main():
|
|||||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(args.output_dir)
|
model_to_save.save_pretrained(args.output_dir)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = model_class.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
checkpoints = list(
|
||||||
|
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||||
|
)
|
||||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||||
|
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = model_class.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -26,8 +26,7 @@ import random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
TensorDataset)
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -37,34 +36,38 @@ except:
|
|||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
from transformers import (
|
||||||
BertForMultipleChoice, BertTokenizer,
|
WEIGHTS_NAME,
|
||||||
XLNetConfig, XLNetForMultipleChoice,
|
BertConfig,
|
||||||
XLNetTokenizer, RobertaConfig,
|
BertForMultipleChoice,
|
||||||
RobertaForMultipleChoice, RobertaTokenizer)
|
BertTokenizer,
|
||||||
|
XLNetConfig,
|
||||||
|
XLNetForMultipleChoice,
|
||||||
|
XLNetTokenizer,
|
||||||
|
RobertaConfig,
|
||||||
|
RobertaForMultipleChoice,
|
||||||
|
RobertaTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||||
|
|
||||||
from utils_multiple_choice import (convert_examples_to_features, processors)
|
from utils_multiple_choice import convert_examples_to_features, processors
|
||||||
|
|
||||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Flat tuple built by concatenating the keys of `pretrained_config_archive_map`
# from each of the three config classes below.
ALL_MODELS = sum(
    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ()
)

# Maps a model-type key to a triple of classes; judging by the names these are
# (config, multiple-choice model, tokenizer) — confirm against the lookup site.
MODEL_CLASSES = {
    "bert": (BertConfig, BertForMultipleChoice, BertTokenizer),
    "xlnet": (XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer),
    "roberta": (RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer),
}
|
||||||
|
|
||||||
|
|
||||||
def select_field(features, field):
    """Collect one field from every choice of every feature.

    Args:
        features: iterable of objects exposing a ``choices_features`` iterable
            whose items support ``item[field]`` lookup.
        field: key looked up in each choice.

    Returns:
        A list with one inner list per feature; each inner list holds
        ``choice[field]`` for that feature's choices, in iteration order.
    """
    return [[choice[field] for choice in feature.choices_features] for feature in features]
|
|
||||||
|
|
||||||
|
|
||||||
def simple_accuracy(preds, labels):
|
def simple_accuracy(preds, labels):
|
||||||
@@ -95,13 +98,18 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
|
|
||||||
# Prepare optimizer and schedule (linear warmup and decay)
|
# Prepare optimizer and schedule (linear warmup and decay)
|
||||||
no_decay = ['bias', 'LayerNorm.weight']
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
{
|
||||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
},
|
||||||
|
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||||
]
|
]
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||||
|
)
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
@@ -115,17 +123,21 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
output_device=args.local_rank,
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
find_unused_parameters=True)
|
)
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_dataset))
|
logger.info(" Num examples = %d", len(train_dataset))
|
||||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
logger.info(
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
|
args.train_batch_size
|
||||||
|
* args.gradient_accumulation_steps
|
||||||
|
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||||
|
)
|
||||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
@@ -141,10 +153,14 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
for step, batch in enumerate(epoch_iterator):
|
for step, batch in enumerate(epoch_iterator):
|
||||||
model.train()
|
model.train()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {
|
||||||
'attention_mask': batch[1],
|
"input_ids": batch[0],
|
||||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
"attention_mask": batch[1],
|
||||||
'labels': batch[3]}
|
"token_type_ids": batch[2]
|
||||||
|
if args.model_type in ["bert", "xlnet"]
|
||||||
|
else None, # XLM don't use segment_ids
|
||||||
|
"labels": batch[3],
|
||||||
|
}
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||||
|
|
||||||
@@ -171,10 +187,12 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
# Log metrics
|
# Log metrics
|
||||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
if (
|
||||||
|
args.local_rank == -1 and args.evaluate_during_training
|
||||||
|
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||||
results = evaluate(args, model, tokenizer)
|
results = evaluate(args, model, tokenizer)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||||
if results["eval_acc"] > best_dev_acc:
|
if results["eval_acc"] > best_dev_acc:
|
||||||
best_dev_acc = results["eval_acc"]
|
best_dev_acc = results["eval_acc"]
|
||||||
best_dev_loss = results["eval_loss"]
|
best_dev_loss = results["eval_loss"]
|
||||||
@@ -182,22 +200,33 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
if args.do_test:
|
if args.do_test:
|
||||||
results_test = evaluate(args, model, tokenizer, test=True)
|
results_test = evaluate(args, model, tokenizer, test=True)
|
||||||
for key, value in results_test.items():
|
for key, value in results_test.items():
|
||||||
tb_writer.add_scalar('test_{}'.format(key), value, global_step)
|
tb_writer.add_scalar("test_{}".format(key), value, global_step)
|
||||||
logger.info("test acc: %s, loss: %s, global steps: %s", str(results_test['eval_acc']), str(results_test['eval_loss']), str(global_step))
|
logger.info(
|
||||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
"test acc: %s, loss: %s, global steps: %s",
|
||||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
str(results_test["eval_acc"]),
|
||||||
logger.info("Average loss: %s at global step: %s", str((tr_loss - logging_loss)/args.logging_steps), str(global_step))
|
str(results_test["eval_loss"]),
|
||||||
|
str(global_step),
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||||
|
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||||
|
logger.info(
|
||||||
|
"Average loss: %s at global step: %s",
|
||||||
|
str((tr_loss - logging_loss) / args.logging_steps),
|
||||||
|
str(global_step),
|
||||||
|
)
|
||||||
logging_loss = tr_loss
|
logging_loss = tr_loss
|
||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(output_dir)
|
model_to_save.save_pretrained(output_dir)
|
||||||
tokenizer.save_vocabulary(output_dir)
|
tokenizer.save_vocabulary(output_dir)
|
||||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
@@ -246,10 +275,14 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
|
|||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {
|
||||||
'attention_mask': batch[1],
|
"input_ids": batch[0],
|
||||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
"attention_mask": batch[1],
|
||||||
'labels': batch[3]}
|
"token_type_ids": batch[2]
|
||||||
|
if args.model_type in ["bert", "xlnet"]
|
||||||
|
else None, # XLM don't use segment_ids
|
||||||
|
"labels": batch[3],
|
||||||
|
}
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
tmp_eval_loss, logits = outputs[:2]
|
tmp_eval_loss, logits = outputs[:2]
|
||||||
|
|
||||||
@@ -257,10 +290,10 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
|
|||||||
nb_eval_steps += 1
|
nb_eval_steps += 1
|
||||||
if preds is None:
|
if preds is None:
|
||||||
preds = logits.detach().cpu().numpy()
|
preds = logits.detach().cpu().numpy()
|
||||||
out_label_ids = inputs['labels'].detach().cpu().numpy()
|
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||||
else:
|
else:
|
||||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||||
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
|
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||||
|
|
||||||
eval_loss = eval_loss / nb_eval_steps
|
eval_loss = eval_loss / nb_eval_steps
|
||||||
preds = np.argmax(preds, axis=1)
|
preds = np.argmax(preds, axis=1)
|
||||||
@@ -273,8 +306,14 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
|
|||||||
with open(output_eval_file, "w") as writer:
|
with open(output_eval_file, "w") as writer:
|
||||||
logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test)))
|
logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test)))
|
||||||
writer.write("model =%s\n" % str(args.model_name_or_path))
|
writer.write("model =%s\n" % str(args.model_name_or_path))
|
||||||
writer.write("total batch size=%d\n" % (args.per_gpu_train_batch_size * args.gradient_accumulation_steps *
|
writer.write(
|
||||||
(torch.distributed.get_world_size() if args.local_rank != -1 else 1)))
|
"total batch size=%d\n"
|
||||||
|
% (
|
||||||
|
args.per_gpu_train_batch_size
|
||||||
|
* args.gradient_accumulation_steps
|
||||||
|
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1)
|
||||||
|
)
|
||||||
|
)
|
||||||
writer.write("train num epochs=%d\n" % args.num_train_epochs)
|
writer.write("train num epochs=%d\n" % args.num_train_epochs)
|
||||||
writer.write("fp16 =%s\n" % args.fp16)
|
writer.write("fp16 =%s\n" % args.fp16)
|
||||||
writer.write("max seq length =%d\n" % args.max_seq_length)
|
writer.write("max seq length =%d\n" % args.max_seq_length)
|
||||||
@@ -291,17 +330,21 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
|
|||||||
processor = processors[task]()
|
processor = processors[task]()
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
if evaluate:
|
if evaluate:
|
||||||
cached_mode = 'dev'
|
cached_mode = "dev"
|
||||||
elif test:
|
elif test:
|
||||||
cached_mode = 'test'
|
cached_mode = "test"
|
||||||
else:
|
else:
|
||||||
cached_mode = 'train'
|
cached_mode = "train"
|
||||||
assert (evaluate == True and test == True) == False
|
assert (evaluate == True and test == True) == False
|
||||||
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
|
cached_features_file = os.path.join(
|
||||||
|
args.data_dir,
|
||||||
|
"cached_{}_{}_{}_{}".format(
|
||||||
cached_mode,
|
cached_mode,
|
||||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||||
str(args.max_seq_length),
|
str(args.max_seq_length),
|
||||||
str(task)))
|
str(task),
|
||||||
|
),
|
||||||
|
)
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
features = torch.load(cached_features_file)
|
features = torch.load(cached_features_file)
|
||||||
@@ -320,8 +363,8 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
|
|||||||
label_list,
|
label_list,
|
||||||
args.max_seq_length,
|
args.max_seq_length,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
pad_on_left=bool(args.model_type in ["xlnet"]), # pad on the left for xlnet
|
||||||
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0
|
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
|
||||||
)
|
)
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
@@ -331,9 +374,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
|
|||||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
|
|
||||||
# Convert to Tensors and build dataset
|
# Convert to Tensors and build dataset
|
||||||
all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long)
|
all_input_ids = torch.tensor(select_field(features, "input_ids"), dtype=torch.long)
|
||||||
all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long)
|
all_input_mask = torch.tensor(select_field(features, "input_mask"), dtype=torch.long)
|
||||||
all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long)
|
all_segment_ids = torch.tensor(select_field(features, "segment_ids"), dtype=torch.long)
|
||||||
all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long)
|
all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||||
|
|
||||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||||
@@ -344,91 +387,150 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
"--data_dir",
|
||||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
default=None,
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
type=str,
|
||||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
required=True,
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
)
|
||||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
parser.add_argument(
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
"--model_type",
|
||||||
help="The output directory where the model predictions and checkpoints will be written.")
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task_name",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
|
)
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument(
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
)
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
parser.add_argument(
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
"--tokenizer_name",
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
default="",
|
||||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
type=str,
|
||||||
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_dir",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_seq_length",
|
||||||
|
default=128,
|
||||||
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded.")
|
"than this will be truncated, sequences shorter will be padded.",
|
||||||
parser.add_argument("--do_train", action='store_true',
|
)
|
||||||
help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action='store_true',
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
help="Whether to run eval on the dev set.")
|
parser.add_argument("--do_test", action="store_true", help="Whether to run test on the test set")
|
||||||
parser.add_argument("--do_test", action='store_true', help='Whether to run test on the test set')
|
parser.add_argument(
|
||||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
|
||||||
help="Run evaluation during training at each logging step.")
|
)
|
||||||
parser.add_argument("--do_lower_case", action='store_true',
|
parser.add_argument(
|
||||||
help="Set this flag if you are using an uncased model.")
|
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||||
help="Batch size per GPU/CPU for training.")
|
parser.add_argument(
|
||||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||||
help="Batch size per GPU/CPU for evaluation.")
|
)
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
parser.add_argument(
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
"--gradient_accumulation_steps",
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
type=int,
|
||||||
help="The initial learning rate for Adam.")
|
default=1,
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
help="Weight deay if we apply some.")
|
)
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
help="Epsilon for Adam optimizer.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
help="Max gradient norm.")
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument(
|
||||||
help="Total number of training epochs to perform.")
|
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
)
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
parser.add_argument(
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
"--max_steps",
|
||||||
help="Linear warmup over warmup_steps.")
|
default=-1,
|
||||||
|
type=int,
|
||||||
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
|
||||||
parser.add_argument('--logging_steps', type=int, default=50,
|
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||||
help="Log every X updates steps.")
|
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||||
parser.add_argument('--save_steps', type=int, default=50,
|
parser.add_argument(
|
||||||
help="Save checkpoint every X updates steps.")
|
"--eval_all_checkpoints",
|
||||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
action="store_true",
|
||||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
)
|
||||||
help="Avoid using CUDA when available")
|
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
parser.add_argument(
|
||||||
help="Overwrite the content of the output directory")
|
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||||
parser.add_argument('--overwrite_cache', action='store_true',
|
)
|
||||||
help="Overwrite the cached training and evaluation sets")
|
parser.add_argument(
|
||||||
parser.add_argument('--seed', type=int, default=42,
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
help="random seed for initialization")
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
|
||||||
parser.add_argument('--fp16', action='store_true',
|
parser.add_argument(
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
"--fp16",
|
||||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
action="store_true",
|
||||||
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16_opt_level",
|
||||||
|
type=str,
|
||||||
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html")
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
parser.add_argument("--local_rank", type=int, default=-1,
|
)
|
||||||
help="For distributed training: local_rank")
|
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
if (
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
os.path.exists(args.output_dir)
|
||||||
|
and os.listdir(args.output_dir)
|
||||||
|
and args.do_train
|
||||||
|
and not args.overwrite_output_dir
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||||
|
args.output_dir
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Setup distant debugging if needed
|
# Setup distant debugging if needed
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
@@ -440,16 +542,24 @@ def main():
|
|||||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||||
torch.cuda.set_device(args.local_rank)
|
torch.cuda.set_device(args.local_rank)
|
||||||
device = torch.device("cuda", args.local_rank)
|
device = torch.device("cuda", args.local_rank)
|
||||||
torch.distributed.init_process_group(backend='nccl')
|
torch.distributed.init_process_group(backend="nccl")
|
||||||
args.n_gpu = 1
|
args.n_gpu = 1
|
||||||
args.device = device
|
args.device = device
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||||
|
args.local_rank,
|
||||||
|
device,
|
||||||
|
args.n_gpu,
|
||||||
|
bool(args.local_rank != -1),
|
||||||
|
args.fp16,
|
||||||
|
)
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
@@ -468,17 +578,23 @@ def main():
|
|||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
config = config_class.from_pretrained(
|
||||||
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
num_labels=num_labels,
|
num_labels=num_labels,
|
||||||
finetuning_task=args.task_name,
|
finetuning_task=args.task_name,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
)
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
do_lower_case=args.do_lower_case,
|
do_lower_case=args.do_lower_case,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
model = model_class.from_pretrained(args.model_name_or_path,
|
)
|
||||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
model = model_class.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
@@ -494,7 +610,6 @@ def main():
|
|||||||
global_step, tr_loss, best_steps = train(args, train_dataset, model, tokenizer)
|
global_step, tr_loss, best_steps = train(args, train_dataset, model, tokenizer)
|
||||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
|
|
||||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
@@ -504,19 +619,20 @@ def main():
|
|||||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(args.output_dir)
|
model_to_save.save_pretrained(args.output_dir)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = model_class.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
@@ -524,17 +640,19 @@ def main():
|
|||||||
args.output_dir = args.model_name_or_path
|
args.output_dir = args.model_name_or_path
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
checkpoints = list(
|
||||||
|
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||||
|
)
|
||||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||||
|
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = model_class.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
|
||||||
if args.do_test and args.local_rank in [-1, 0]:
|
if args.do_test and args.local_rank in [-1, 0]:
|
||||||
@@ -546,13 +664,13 @@ def main():
|
|||||||
# logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
# logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||||
|
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = model_class.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result = evaluate(args, model, tokenizer, prefix=prefix, test=True)
|
result = evaluate(args, model, tokenizer, prefix=prefix, test=True)
|
||||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||||
results.update(result)
|
results.update(result)
|
||||||
if best_steps:
|
if best_steps:
|
||||||
logger.info("best steps of eval acc is the following checkpoints: %s", best_steps)
|
logger.info("best steps of eval acc is the following checkpoints: %s", best_steps)
|
||||||
|
|||||||
@@ -43,9 +43,12 @@ from transformers import XLMRobertaConfig, XLMRobertaForTokenClassification, XLM
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum(
|
ALL_MODELS = sum(
|
||||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig,
|
(
|
||||||
CamembertConfig, XLMRobertaConfig)),
|
tuple(conf.pretrained_config_archive_map.keys())
|
||||||
())
|
for conf in (BertConfig, RobertaConfig, DistilBertConfig, CamembertConfig, XLMRobertaConfig)
|
||||||
|
),
|
||||||
|
(),
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
||||||
@@ -82,18 +85,24 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
|||||||
# Prepare optimizer and schedule (linear warmup and decay)
|
# Prepare optimizer and schedule (linear warmup and decay)
|
||||||
no_decay = ["bias", "LayerNorm.weight"]
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
{
|
||||||
"weight_decay": args.weight_decay},
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
|
"weight_decay": args.weight_decay,
|
||||||
|
},
|
||||||
|
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||||
]
|
]
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||||
|
)
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
|
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||||
|
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||||
|
):
|
||||||
# Load in optimizer and scheduler states
|
# Load in optimizer and scheduler states
|
||||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
|
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
|
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
@@ -108,18 +117,21 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
|||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
output_device=args.local_rank,
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
find_unused_parameters=True)
|
)
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_dataset))
|
logger.info(" Num examples = %d", len(train_dataset))
|
||||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
logger.info(
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * (
|
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
args.train_batch_size
|
||||||
|
* args.gradient_accumulation_steps
|
||||||
|
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||||
|
)
|
||||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
@@ -129,7 +141,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
|||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
if os.path.exists(args.model_name_or_path):
|
if os.path.exists(args.model_name_or_path):
|
||||||
# set global_step to gobal_step of last saved checkpoint from model path
|
# set global_step to gobal_step of last saved checkpoint from model path
|
||||||
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
|
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
|
|
||||||
@@ -140,7 +152,9 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
|||||||
|
|
||||||
tr_loss, logging_loss = 0.0, 0.0
|
tr_loss, logging_loss = 0.0, 0.0
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
train_iterator = trange(
|
||||||
|
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||||
|
)
|
||||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||||
for _ in train_iterator:
|
for _ in train_iterator:
|
||||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||||
@@ -153,11 +167,11 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
|||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {"input_ids": batch[0],
|
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||||
"attention_mask": batch[1],
|
|
||||||
"labels": batch[3]}
|
|
||||||
if args.model_type != "distilbert":
|
if args.model_type != "distilbert":
|
||||||
inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
inputs["token_type_ids"] = (
|
||||||
|
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||||
|
) # XLM and RoBERTa don"t use segment_ids
|
||||||
|
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||||
@@ -187,7 +201,9 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
|||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
# Log metrics
|
# Log metrics
|
||||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
if (
|
||||||
|
args.local_rank == -1 and args.evaluate_during_training
|
||||||
|
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||||
results, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev")
|
results, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev")
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||||
@@ -200,15 +216,17 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
|||||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
model_to_save = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(output_dir)
|
model_to_save.save_pretrained(output_dir)
|
||||||
tokenizer.save_pretrained(output_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
|
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
|
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
@@ -249,11 +267,11 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
|
|||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {"input_ids": batch[0],
|
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||||
"attention_mask": batch[1],
|
|
||||||
"labels": batch[3]}
|
|
||||||
if args.model_type != "distilbert":
|
if args.model_type != "distilbert":
|
||||||
inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
inputs["token_type_ids"] = (
|
||||||
|
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||||
|
) # XLM and RoBERTa don"t use segment_ids
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
tmp_eval_loss, logits = outputs[:2]
|
tmp_eval_loss, logits = outputs[:2]
|
||||||
|
|
||||||
@@ -287,7 +305,7 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
|
|||||||
"loss": eval_loss,
|
"loss": eval_loss,
|
||||||
"precision": precision_score(out_label_list, preds_list),
|
"precision": precision_score(out_label_list, preds_list),
|
||||||
"recall": recall_score(out_label_list, preds_list),
|
"recall": recall_score(out_label_list, preds_list),
|
||||||
"f1": f1_score(out_label_list, preds_list)
|
"f1": f1_score(out_label_list, preds_list),
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("***** Eval results %s *****", prefix)
|
logger.info("***** Eval results %s *****", prefix)
|
||||||
@@ -302,16 +320,23 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
|
|||||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
|
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
cached_features_file = os.path.join(args.data_dir, "cached_{}_{}_{}".format(mode,
|
cached_features_file = os.path.join(
|
||||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
args.data_dir,
|
||||||
str(args.max_seq_length)))
|
"cached_{}_{}_{}".format(
|
||||||
|
mode, list(filter(None, args.model_name_or_path.split("/"))).pop(), str(args.max_seq_length)
|
||||||
|
),
|
||||||
|
)
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
features = torch.load(cached_features_file)
|
features = torch.load(cached_features_file)
|
||||||
else:
|
else:
|
||||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||||
examples = read_examples_from_file(args.data_dir, mode)
|
examples = read_examples_from_file(args.data_dir, mode)
|
||||||
features = convert_examples_to_features(examples, labels, args.max_seq_length, tokenizer,
|
features = convert_examples_to_features(
|
||||||
|
examples,
|
||||||
|
labels,
|
||||||
|
args.max_seq_length,
|
||||||
|
tokenizer,
|
||||||
cls_token_at_end=bool(args.model_type in ["xlnet"]),
|
cls_token_at_end=bool(args.model_type in ["xlnet"]),
|
||||||
# xlnet has a cls token at the end
|
# xlnet has a cls token at the end
|
||||||
cls_token=tokenizer.cls_token,
|
cls_token=tokenizer.cls_token,
|
||||||
@@ -323,7 +348,7 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
|
|||||||
# pad on the left for xlnet
|
# pad on the left for xlnet
|
||||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||||
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
|
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
|
||||||
pad_token_label_id=pad_token_label_id
|
pad_token_label_id=pad_token_label_id,
|
||||||
)
|
)
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
@@ -346,95 +371,151 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.")
|
"--data_dir",
|
||||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
default=None,
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
type=str,
|
||||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
required=True,
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
)
|
||||||
help="The output directory where the model predictions and checkpoints will be written.")
|
parser.add_argument(
|
||||||
|
"--model_type",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
|
)
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--labels", default="", type=str,
|
parser.add_argument(
|
||||||
help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.")
|
"--labels",
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
default="",
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
type=str,
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.",
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
)
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
parser.add_argument(
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer_name",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_dir",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_seq_length",
|
||||||
|
default=128,
|
||||||
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded.")
|
"than this will be truncated, sequences shorter will be padded.",
|
||||||
parser.add_argument("--do_train", action="store_true",
|
)
|
||||||
help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action="store_true",
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
help="Whether to run eval on the dev set.")
|
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
|
||||||
parser.add_argument("--do_predict", action="store_true",
|
parser.add_argument(
|
||||||
help="Whether to run predictions on the test set.")
|
"--evaluate_during_training",
|
||||||
parser.add_argument("--evaluate_during_training", action="store_true",
|
action="store_true",
|
||||||
help="Whether to run evaluation during training at each logging step.")
|
help="Whether to run evaluation during training at each logging step.",
|
||||||
parser.add_argument("--do_lower_case", action="store_true",
|
)
|
||||||
help="Set this flag if you are using an uncased model.")
|
parser.add_argument(
|
||||||
|
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||||
help="Batch size per GPU/CPU for training.")
|
parser.add_argument(
|
||||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||||
help="Batch size per GPU/CPU for evaluation.")
|
)
|
||||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
|
parser.add_argument(
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
"--gradient_accumulation_steps",
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
type=int,
|
||||||
help="The initial learning rate for Adam.")
|
default=1,
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
help="Weight decay if we apply some.")
|
)
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
help="Epsilon for Adam optimizer.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
help="Max gradient norm.")
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument(
|
||||||
help="Total number of training epochs to perform.")
|
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
)
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
parser.add_argument(
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
"--max_steps",
|
||||||
help="Linear warmup over warmup_steps.")
|
default=-1,
|
||||||
|
type=int,
|
||||||
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
|
||||||
parser.add_argument("--logging_steps", type=int, default=50,
|
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||||
help="Log every X updates steps.")
|
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||||
parser.add_argument("--save_steps", type=int, default=50,
|
parser.add_argument(
|
||||||
help="Save checkpoint every X updates steps.")
|
"--eval_all_checkpoints",
|
||||||
parser.add_argument("--eval_all_checkpoints", action="store_true",
|
action="store_true",
|
||||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||||
parser.add_argument("--no_cuda", action="store_true",
|
)
|
||||||
help="Avoid using CUDA when available")
|
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||||
parser.add_argument("--overwrite_output_dir", action="store_true",
|
parser.add_argument(
|
||||||
help="Overwrite the content of the output directory")
|
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||||
parser.add_argument("--overwrite_cache", action="store_true",
|
)
|
||||||
help="Overwrite the cached training and evaluation sets")
|
parser.add_argument(
|
||||||
parser.add_argument("--seed", type=int, default=42,
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
help="random seed for initialization")
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
|
||||||
parser.add_argument("--fp16", action="store_true",
|
parser.add_argument(
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
"--fp16",
|
||||||
parser.add_argument("--fp16_opt_level", type=str, default="O1",
|
action="store_true",
|
||||||
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16_opt_level",
|
||||||
|
type=str,
|
||||||
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html")
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
parser.add_argument("--local_rank", type=int, default=-1,
|
)
|
||||||
help="For distributed training: local_rank")
|
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(
|
if (
|
||||||
args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
os.path.exists(args.output_dir)
|
||||||
|
and os.listdir(args.output_dir)
|
||||||
|
and args.do_train
|
||||||
|
and not args.overwrite_output_dir
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||||
args.output_dir))
|
args.output_dir
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Setup distant debugging if needed
|
# Setup distant debugging if needed
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
@@ -451,11 +532,19 @@ def main():
|
|||||||
args.device = device
|
args.device = device
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
datefmt="%m/%d/%Y %H:%M:%S",
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
)
|
||||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
logger.warning(
|
||||||
|
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||||
|
args.local_rank,
|
||||||
|
device,
|
||||||
|
args.n_gpu,
|
||||||
|
bool(args.local_rank != -1),
|
||||||
|
args.fp16,
|
||||||
|
)
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
@@ -472,16 +561,22 @@ def main():
|
|||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
config = config_class.from_pretrained(
|
||||||
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
num_labels=num_labels,
|
num_labels=num_labels,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
)
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
do_lower_case=args.do_lower_case,
|
do_lower_case=args.do_lower_case,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
model = model_class.from_pretrained(args.model_name_or_path,
|
)
|
||||||
|
model = model_class.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
@@ -505,7 +600,9 @@ def main():
|
|||||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
model_to_save = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(args.output_dir)
|
model_to_save.save_pretrained(args.output_dir)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
@@ -518,7 +615,9 @@ def main():
|
|||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True)))
|
checkpoints = list(
|
||||||
|
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||||
|
)
|
||||||
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
@@ -565,4 +664,3 @@ def main():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,11 @@
|
|||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
|
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
|
||||||
from transformers.data.metrics.squad_metrics import compute_predictions_logits, compute_predictions_log_probs, squad_evaluate
|
from transformers.data.metrics.squad_metrics import (
|
||||||
|
compute_predictions_logits,
|
||||||
|
compute_predictions_log_probs,
|
||||||
|
squad_evaluate,
|
||||||
|
)
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
@@ -27,8 +31,7 @@ import glob
|
|||||||
import timeit
|
import timeit
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import (
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
DataLoader, RandomSampler, SequentialSampler, TensorDataset)
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -38,32 +41,47 @@ except:
|
|||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
from transformers import (
|
||||||
BertForQuestionAnswering, BertTokenizer,
|
WEIGHTS_NAME,
|
||||||
RobertaForQuestionAnswering, RobertaTokenizer, RobertaConfig,
|
BertConfig,
|
||||||
XLMConfig, XLMForQuestionAnswering,
|
BertForQuestionAnswering,
|
||||||
XLMTokenizer, XLNetConfig,
|
BertTokenizer,
|
||||||
|
RobertaForQuestionAnswering,
|
||||||
|
RobertaTokenizer,
|
||||||
|
RobertaConfig,
|
||||||
|
XLMConfig,
|
||||||
|
XLMForQuestionAnswering,
|
||||||
|
XLMTokenizer,
|
||||||
|
XLNetConfig,
|
||||||
XLNetForQuestionAnswering,
|
XLNetForQuestionAnswering,
|
||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer,
|
DistilBertConfig,
|
||||||
AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer,
|
DistilBertForQuestionAnswering,
|
||||||
XLMConfig, XLMForQuestionAnswering, XLMTokenizer,
|
DistilBertTokenizer,
|
||||||
|
AlbertConfig,
|
||||||
|
AlbertForQuestionAnswering,
|
||||||
|
AlbertTokenizer,
|
||||||
|
XLMConfig,
|
||||||
|
XLMForQuestionAnswering,
|
||||||
|
XLMTokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features
|
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
ALL_MODELS = sum(
|
||||||
for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)), ())
|
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)),
|
||||||
|
(),
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||||
'roberta': (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
|
"roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
|
||||||
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||||
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||||
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||||
'albert': (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
|
"albert": (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -85,49 +103,44 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
tb_writer = SummaryWriter()
|
tb_writer = SummaryWriter()
|
||||||
|
|
||||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||||
train_sampler = RandomSampler(
|
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||||
train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
train_dataloader = DataLoader(
|
|
||||||
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
|
||||||
|
|
||||||
if args.max_steps > 0:
|
if args.max_steps > 0:
|
||||||
t_total = args.max_steps
|
t_total = args.max_steps
|
||||||
args.num_train_epochs = args.max_steps // (
|
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||||
len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
|
||||||
else:
|
else:
|
||||||
t_total = len(
|
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
|
||||||
|
|
||||||
# Prepare optimizer and schedule (linear warmup and decay)
|
# Prepare optimizer and schedule (linear warmup and decay)
|
||||||
no_decay = ['bias', 'LayerNorm.weight']
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{'params': [p for n, p in model.named_parameters() if not any(
|
{
|
||||||
nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
{'params': [p for n, p in model.named_parameters() if any(
|
"weight_decay": args.weight_decay,
|
||||||
nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
},
|
||||||
|
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||||
]
|
]
|
||||||
optimizer = AdamW(optimizer_grouped_parameters,
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
lr=args.learning_rate, eps=args.adam_epsilon)
|
|
||||||
scheduler = get_linear_schedule_with_warmup(
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||||
|
)
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
|
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||||
|
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||||
|
):
|
||||||
# Load in optimizer and scheduler states
|
# Load in optimizer and scheduler states
|
||||||
optimizer.load_state_dict(torch.load(
|
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||||
os.path.join(args.model_name_or_path, 'optimizer.pt')))
|
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||||
scheduler.load_state_dict(torch.load(
|
|
||||||
os.path.join(args.model_name_or_path, 'scheduler.pt')))
|
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||||
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
|
||||||
|
|
||||||
model, optimizer = amp.initialize(
|
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||||
model, optimizer, opt_level=args.fp16_opt_level)
|
|
||||||
|
|
||||||
# multi-gpu training (should be after apex fp16 initialization)
|
# multi-gpu training (should be after apex fp16 initialization)
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
@@ -135,20 +148,22 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
output_device=args.local_rank,
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
find_unused_parameters=True)
|
)
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_dataset))
|
logger.info(" Num examples = %d", len(train_dataset))
|
||||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||||
logger.info(" Instantaneous batch size per GPU = %d",
|
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||||
args.per_gpu_train_batch_size)
|
logger.info(
|
||||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
args.train_batch_size
|
||||||
logger.info(" Gradient Accumulation steps = %d",
|
* args.gradient_accumulation_steps
|
||||||
args.gradient_accumulation_steps)
|
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||||
|
)
|
||||||
|
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
global_step = 1
|
global_step = 1
|
||||||
@@ -157,29 +172,25 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
if os.path.exists(args.model_name_or_path):
|
if os.path.exists(args.model_name_or_path):
|
||||||
# set global_step to gobal_step of last saved checkpoint from model path
|
# set global_step to gobal_step of last saved checkpoint from model path
|
||||||
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
|
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||||
epochs_trained = global_step // (len(train_dataloader) //
|
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
args.gradient_accumulation_steps)
|
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
steps_trained_in_current_epoch = global_step % (
|
|
||||||
len(train_dataloader) // args.gradient_accumulation_steps)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||||
" Continuing training from checkpoint, will skip to saved global_step")
|
|
||||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||||
logger.info(" Continuing training from global step %d", global_step)
|
logger.info(" Continuing training from global step %d", global_step)
|
||||||
logger.info(" Will skip the first %d steps in the first epoch",
|
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||||
steps_trained_in_current_epoch)
|
|
||||||
|
|
||||||
tr_loss, logging_loss = 0.0, 0.0
|
tr_loss, logging_loss = 0.0, 0.0
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
train_iterator = trange(epochs_trained, int(
|
train_iterator = trange(
|
||||||
args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||||
|
)
|
||||||
# Added here for reproductibility (even between python 2 and 3)
|
# Added here for reproductibility (even between python 2 and 3)
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
|
|
||||||
for _ in train_iterator:
|
for _ in train_iterator:
|
||||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration",
|
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||||
disable=args.local_rank not in [-1, 0])
|
|
||||||
for step, batch in enumerate(epoch_iterator):
|
for step, batch in enumerate(epoch_iterator):
|
||||||
|
|
||||||
# Skip past any already trained steps if resuming training
|
# Skip past any already trained steps if resuming training
|
||||||
@@ -191,18 +202,17 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
|
|
||||||
inputs = {
|
inputs = {
|
||||||
'input_ids': batch[0],
|
"input_ids": batch[0],
|
||||||
'attention_mask': batch[1],
|
"attention_mask": batch[1],
|
||||||
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
|
"token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
|
||||||
'start_positions': batch[3],
|
"start_positions": batch[3],
|
||||||
'end_positions': batch[4],
|
"end_positions": batch[4],
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
inputs.update({'cls_index': batch[5],
|
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
||||||
'p_mask': batch[6]})
|
|
||||||
if args.version_2_with_negative:
|
if args.version_2_with_negative:
|
||||||
inputs.update({'is_impossible': batch[7]})
|
inputs.update({"is_impossible": batch[7]})
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
# model outputs are always tuple in transformers (see doc)
|
# model outputs are always tuple in transformers (see doc)
|
||||||
loss = outputs[0]
|
loss = outputs[0]
|
||||||
@@ -221,11 +231,9 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
torch.nn.utils.clip_grad_norm_(
|
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||||
amp.master_params(optimizer), args.max_grad_norm)
|
|
||||||
else:
|
else:
|
||||||
torch.nn.utils.clip_grad_norm_(
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||||
model.parameters(), args.max_grad_norm)
|
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step() # Update learning rate schedule
|
scheduler.step() # Update learning rate schedule
|
||||||
@@ -238,36 +246,27 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
if args.local_rank == -1 and args.evaluate_during_training:
|
if args.local_rank == -1 and args.evaluate_during_training:
|
||||||
results = evaluate(args, model, tokenizer)
|
results = evaluate(args, model, tokenizer)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
tb_writer.add_scalar(
|
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||||
'eval_{}'.format(key), value, global_step)
|
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||||
tb_writer.add_scalar(
|
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||||
'lr', scheduler.get_lr()[0], global_step)
|
|
||||||
tb_writer.add_scalar(
|
|
||||||
'loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
|
||||||
logging_loss = tr_loss
|
logging_loss = tr_loss
|
||||||
|
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||||
output_dir = os.path.join(
|
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||||
args.output_dir, 'checkpoint-{}'.format(global_step))
|
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
# Take care of distributed/parallel training
|
# Take care of distributed/parallel training
|
||||||
model_to_save = model.module if hasattr(
|
model_to_save = model.module if hasattr(model, "module") else model
|
||||||
model, 'module') else model
|
|
||||||
model_to_save.save_pretrained(output_dir)
|
model_to_save.save_pretrained(output_dir)
|
||||||
tokenizer.save_pretrained(output_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
torch.save(args, os.path.join(
|
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||||
output_dir, 'training_args.bin'))
|
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
torch.save(optimizer.state_dict(), os.path.join(
|
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
output_dir, 'optimizer.pt'))
|
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
torch.save(scheduler.state_dict(), os.path.join(
|
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||||
output_dir, 'scheduler.pt'))
|
|
||||||
logger.info(
|
|
||||||
"Saving optimizer and scheduler states to %s", output_dir)
|
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
epoch_iterator.close()
|
epoch_iterator.close()
|
||||||
@@ -283,8 +282,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
|
|
||||||
def evaluate(args, model, tokenizer, prefix=""):
|
def evaluate(args, model, tokenizer, prefix=""):
|
||||||
dataset, examples, features = load_and_cache_examples(
|
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
|
||||||
args, tokenizer, evaluate=True, output_examples=True)
|
|
||||||
|
|
||||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||||
os.makedirs(args.output_dir)
|
os.makedirs(args.output_dir)
|
||||||
@@ -293,8 +291,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
|
|
||||||
# Note that DistributedSampler samples randomly
|
# Note that DistributedSampler samples randomly
|
||||||
eval_sampler = SequentialSampler(dataset)
|
eval_sampler = SequentialSampler(dataset)
|
||||||
eval_dataloader = DataLoader(
|
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||||
dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
|
||||||
|
|
||||||
# multi-gpu evaluate
|
# multi-gpu evaluate
|
||||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
||||||
@@ -314,15 +311,15 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {
|
inputs = {
|
||||||
'input_ids': batch[0],
|
"input_ids": batch[0],
|
||||||
'attention_mask': batch[1],
|
"attention_mask": batch[1],
|
||||||
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
|
"token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
|
||||||
}
|
}
|
||||||
example_indices = batch[3]
|
example_indices = batch[3]
|
||||||
|
|
||||||
# XLNet and XLM use more arguments for their predictions
|
# XLNet and XLM use more arguments for their predictions
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
inputs.update({'cls_index': batch[4], 'p_mask': batch[5]})
|
inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
|
||||||
|
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
@@ -342,53 +339,68 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
cls_logits = output[4]
|
cls_logits = output[4]
|
||||||
|
|
||||||
result = SquadResult(
|
result = SquadResult(
|
||||||
unique_id, start_logits, end_logits,
|
unique_id,
|
||||||
|
start_logits,
|
||||||
|
end_logits,
|
||||||
start_top_index=start_top_index,
|
start_top_index=start_top_index,
|
||||||
end_top_index=end_top_index,
|
end_top_index=end_top_index,
|
||||||
cls_logits=cls_logits
|
cls_logits=cls_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
start_logits, end_logits = output
|
start_logits, end_logits = output
|
||||||
result = SquadResult(
|
result = SquadResult(unique_id, start_logits, end_logits)
|
||||||
unique_id, start_logits, end_logits
|
|
||||||
)
|
|
||||||
|
|
||||||
all_results.append(result)
|
all_results.append(result)
|
||||||
|
|
||||||
evalTime = timeit.default_timer() - start_time
|
evalTime = timeit.default_timer() - start_time
|
||||||
logger.info(" Evaluation done in total %f secs (%f sec per example)",
|
logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))
|
||||||
evalTime, evalTime / len(dataset))
|
|
||||||
|
|
||||||
# Compute predictions
|
# Compute predictions
|
||||||
output_prediction_file = os.path.join(
|
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
|
||||||
args.output_dir, "predictions_{}.json".format(prefix))
|
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
||||||
output_nbest_file = os.path.join(
|
|
||||||
args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
|
||||||
|
|
||||||
if args.version_2_with_negative:
|
if args.version_2_with_negative:
|
||||||
output_null_log_odds_file = os.path.join(
|
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
|
||||||
args.output_dir, "null_odds_{}.json".format(prefix))
|
|
||||||
else:
|
else:
|
||||||
output_null_log_odds_file = None
|
output_null_log_odds_file = None
|
||||||
|
|
||||||
# XLNet and XLM use a more complex post-processing procedure
|
# XLNet and XLM use a more complex post-processing procedure
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
start_n_top = model.config.start_n_top if hasattr(
|
start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
|
||||||
model, "config") else model.module.config.start_n_top
|
end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top
|
||||||
end_n_top = model.config.end_n_top if hasattr(
|
|
||||||
model, "config") else model.module.config.end_n_top
|
|
||||||
|
|
||||||
predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
|
predictions = compute_predictions_log_probs(
|
||||||
args.max_answer_length, output_prediction_file,
|
examples,
|
||||||
output_nbest_file, output_null_log_odds_file,
|
features,
|
||||||
start_n_top, end_n_top,
|
all_results,
|
||||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
args.n_best_size,
|
||||||
|
args.max_answer_length,
|
||||||
|
output_prediction_file,
|
||||||
|
output_nbest_file,
|
||||||
|
output_null_log_odds_file,
|
||||||
|
start_n_top,
|
||||||
|
end_n_top,
|
||||||
|
args.version_2_with_negative,
|
||||||
|
tokenizer,
|
||||||
|
args.verbose_logging,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
|
predictions = compute_predictions_logits(
|
||||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
examples,
|
||||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
features,
|
||||||
args.version_2_with_negative, args.null_score_diff_threshold, tokenizer)
|
all_results,
|
||||||
|
args.n_best_size,
|
||||||
|
args.max_answer_length,
|
||||||
|
args.do_lower_case,
|
||||||
|
output_prediction_file,
|
||||||
|
output_nbest_file,
|
||||||
|
output_null_log_odds_file,
|
||||||
|
args.verbose_logging,
|
||||||
|
args.version_2_with_negative,
|
||||||
|
args.null_score_diff_threshold,
|
||||||
|
tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
# Compute the F1 and exact scores.
|
# Compute the F1 and exact scores.
|
||||||
results = squad_evaluate(examples, predictions)
|
results = squad_evaluate(examples, predictions)
|
||||||
@@ -402,16 +414,18 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
|
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
input_dir = args.data_dir if args.data_dir else "."
|
input_dir = args.data_dir if args.data_dir else "."
|
||||||
cached_features_file = os.path.join(input_dir, 'cached_{}_{}_{}'.format(
|
cached_features_file = os.path.join(
|
||||||
'dev' if evaluate else 'train',
|
input_dir,
|
||||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
"cached_{}_{}_{}".format(
|
||||||
str(args.max_seq_length))
|
"dev" if evaluate else "train",
|
||||||
|
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||||
|
str(args.max_seq_length),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init features and dataset from cache if it exists
|
# Init features and dataset from cache if it exists
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||||
logger.info("Loading features from cached file %s",
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
cached_features_file)
|
|
||||||
features_and_dataset = torch.load(cached_features_file)
|
features_and_dataset = torch.load(cached_features_file)
|
||||||
features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
|
features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
|
||||||
else:
|
else:
|
||||||
@@ -421,16 +435,13 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
try:
|
try:
|
||||||
import tensorflow_datasets as tfds
|
import tensorflow_datasets as tfds
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
|
||||||
"If not data_dir is specified, tensorflow_datasets needs to be installed.")
|
|
||||||
|
|
||||||
if args.version_2_with_negative:
|
if args.version_2_with_negative:
|
||||||
logger.warn(
|
logger.warn("tensorflow_datasets does not handle version 2 of SQuAD.")
|
||||||
"tensorflow_datasets does not handle version 2 of SQuAD.")
|
|
||||||
|
|
||||||
tfds_examples = tfds.load("squad")
|
tfds_examples = tfds.load("squad")
|
||||||
examples = SquadV1Processor().get_examples_from_dataset(
|
examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
|
||||||
tfds_examples, evaluate=evaluate)
|
|
||||||
else:
|
else:
|
||||||
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
||||||
if evaluate:
|
if evaluate:
|
||||||
@@ -445,15 +456,13 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
doc_stride=args.doc_stride,
|
doc_stride=args.doc_stride,
|
||||||
max_query_length=args.max_query_length,
|
max_query_length=args.max_query_length,
|
||||||
is_training=not evaluate,
|
is_training=not evaluate,
|
||||||
return_dataset='pt',
|
return_dataset="pt",
|
||||||
threads=args.threads,
|
threads=args.threads,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
logger.info("Saving features into cached file %s",
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
cached_features_file)
|
torch.save({"features": features, "dataset": dataset}, cached_features_file)
|
||||||
torch.save({"features": features, "dataset": dataset},
|
|
||||||
cached_features_file)
|
|
||||||
|
|
||||||
if args.local_rank == 0 and not evaluate:
|
if args.local_rank == 0 and not evaluate:
|
||||||
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
@@ -468,140 +477,232 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
# Required parameters
|
# Required parameters
|
||||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
"--model_type",
|
||||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
default=None,
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
type=str,
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
required=True,
|
||||||
help="The output directory where the model checkpoints and predictions will be written.")
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The output directory where the model checkpoints and predictions will be written.",
|
||||||
|
)
|
||||||
|
|
||||||
# Other parameters
|
# Other parameters
|
||||||
parser.add_argument("--data_dir", default=None, type=str,
|
parser.add_argument(
|
||||||
help="The input data dir. Should contain the .json files for the task." +
|
"--data_dir",
|
||||||
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
|
default=None,
|
||||||
parser.add_argument("--train_file", default=None, type=str,
|
type=str,
|
||||||
help="The input training file. If a data dir is specified, will look for the file there" +
|
help="The input data dir. Should contain the .json files for the task."
|
||||||
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
|
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||||
parser.add_argument("--predict_file", default=None, type=str,
|
)
|
||||||
help="The input evaluation file. If a data dir is specified, will look for the file there" +
|
parser.add_argument(
|
||||||
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
|
"--train_file",
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
default=None,
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
type=str,
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
help="The input training file. If a data dir is specified, will look for the file there"
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
)
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
parser.add_argument(
|
||||||
|
"--predict_file",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="The input evaluation file. If a data dir is specified, will look for the file there"
|
||||||
|
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer_name",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_dir",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('--version_2_with_negative', action='store_true',
|
parser.add_argument(
|
||||||
help='If true, the SQuAD examples contain some that do not have an answer.')
|
"--version_2_with_negative",
|
||||||
parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
|
action="store_true",
|
||||||
help="If null_score - best_non_null is greater than the threshold predict null.")
|
help="If true, the SQuAD examples contain some that do not have an answer.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--null_score_diff_threshold",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="If null_score - best_non_null is greater than the threshold predict null.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--max_seq_length", default=384, type=int,
|
parser.add_argument(
|
||||||
|
"--max_seq_length",
|
||||||
|
default=384,
|
||||||
|
type=int,
|
||||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||||
"longer than this will be truncated, and sequences shorter than this will be padded.")
|
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
||||||
parser.add_argument("--doc_stride", default=128, type=int,
|
)
|
||||||
help="When splitting up a long document into chunks, how much stride to take between chunks.")
|
parser.add_argument(
|
||||||
parser.add_argument("--max_query_length", default=64, type=int,
|
"--doc_stride",
|
||||||
|
default=128,
|
||||||
|
type=int,
|
||||||
|
help="When splitting up a long document into chunks, how much stride to take between chunks.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_query_length",
|
||||||
|
default=64,
|
||||||
|
type=int,
|
||||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||||
"be truncated to this length.")
|
"be truncated to this length.",
|
||||||
parser.add_argument("--do_train", action='store_true',
|
)
|
||||||
help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action='store_true',
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
help="Whether to run eval on the dev set.")
|
parser.add_argument(
|
||||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||||
help="Rul evaluation during training at each logging step.")
|
)
|
||||||
parser.add_argument("--do_lower_case", action='store_true',
|
parser.add_argument(
|
||||||
help="Set this flag if you are using an uncased model.")
|
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||||
help="Batch size per GPU/CPU for training.")
|
parser.add_argument(
|
||||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||||
help="Batch size per GPU/CPU for evaluation.")
|
)
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
help="The initial learning rate for Adam.")
|
parser.add_argument(
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
"--gradient_accumulation_steps",
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
type=int,
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
default=1,
|
||||||
help="Weight decay if we apply some.")
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
)
|
||||||
help="Epsilon for Adam optimizer.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
help="Max gradient norm.")
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument(
|
||||||
help="Total number of training epochs to perform.")
|
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
)
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
parser.add_argument(
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
"--max_steps",
|
||||||
help="Linear warmup over warmup_steps.")
|
default=-1,
|
||||||
parser.add_argument("--n_best_size", default=20, type=int,
|
type=int,
|
||||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||||
parser.add_argument("--max_answer_length", default=30, type=int,
|
)
|
||||||
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_best_size",
|
||||||
|
default=20,
|
||||||
|
type=int,
|
||||||
|
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_answer_length",
|
||||||
|
default=30,
|
||||||
|
type=int,
|
||||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
"and end predictions are not conditioned on one another.")
|
"and end predictions are not conditioned on one another.",
|
||||||
parser.add_argument("--verbose_logging", action='store_true',
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose_logging",
|
||||||
|
action="store_true",
|
||||||
help="If true, all of the warnings related to data processing will be printed. "
|
help="If true, all of the warnings related to data processing will be printed. "
|
||||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
"A number of warnings are expected for a normal SQuAD evaluation.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('--logging_steps', type=int, default=50,
|
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||||
help="Log every X updates steps.")
|
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||||
parser.add_argument('--save_steps', type=int, default=50,
|
parser.add_argument(
|
||||||
help="Save checkpoint every X updates steps.")
|
"--eval_all_checkpoints",
|
||||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
action="store_true",
|
||||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
)
|
||||||
help="Whether not to use CUDA when available")
|
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
parser.add_argument(
|
||||||
help="Overwrite the content of the output directory")
|
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||||
parser.add_argument('--overwrite_cache', action='store_true',
|
)
|
||||||
help="Overwrite the cached training and evaluation sets")
|
parser.add_argument(
|
||||||
parser.add_argument('--seed', type=int, default=42,
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
help="random seed for initialization")
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
|
||||||
parser.add_argument("--local_rank", type=int, default=-1,
|
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||||
help="local_rank for distributed training on gpus")
|
parser.add_argument(
|
||||||
parser.add_argument('--fp16', action='store_true',
|
"--fp16",
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
action="store_true",
|
||||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16_opt_level",
|
||||||
|
type=str,
|
||||||
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html")
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
)
|
||||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||||
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||||
|
|
||||||
parser.add_argument('--threads', type=int, default=1, help='multiple threads for converting example to features')
|
parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
if (
|
||||||
|
os.path.exists(args.output_dir)
|
||||||
|
and os.listdir(args.output_dir)
|
||||||
|
and args.do_train
|
||||||
|
and not args.overwrite_output_dir
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||||
|
args.output_dir
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Setup distant debugging if needed
|
# Setup distant debugging if needed
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
address=(args.server_ip, args.server_port), redirect_output=True)
|
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
|
|
||||||
# Setup CUDA, GPU & distributed training
|
# Setup CUDA, GPU & distributed training
|
||||||
if args.local_rank == -1 or args.no_cuda:
|
if args.local_rank == -1 or args.no_cuda:
|
||||||
device = torch.device(
|
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||||
"cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
|
||||||
args.n_gpu = torch.cuda.device_count()
|
args.n_gpu = torch.cuda.device_count()
|
||||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||||
torch.cuda.set_device(args.local_rank)
|
torch.cuda.set_device(args.local_rank)
|
||||||
device = torch.device("cuda", args.local_rank)
|
device = torch.device("cuda", args.local_rank)
|
||||||
torch.distributed.init_process_group(backend='nccl')
|
torch.distributed.init_process_group(backend="nccl")
|
||||||
args.n_gpu = 1
|
args.n_gpu = 1
|
||||||
args.device = device
|
args.device = device
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt='%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||||
|
args.local_rank,
|
||||||
|
device,
|
||||||
|
args.n_gpu,
|
||||||
|
bool(args.local_rank != -1),
|
||||||
|
args.fp16,
|
||||||
|
)
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
@@ -613,16 +714,21 @@ def main():
|
|||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
config = config_class.from_pretrained(
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
do_lower_case=args.do_lower_case,
|
do_lower_case=args.do_lower_case,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
model = model_class.from_pretrained(args.model_name_or_path,
|
)
|
||||||
from_tf=bool(
|
model = model_class.from_pretrained(
|
||||||
'.ckpt' in args.model_name_or_path),
|
args.model_name_or_path,
|
||||||
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
# Make sure only the first process in distributed training will download model & vocab
|
# Make sure only the first process in distributed training will download model & vocab
|
||||||
@@ -638,18 +744,16 @@ def main():
|
|||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
import apex
|
import apex
|
||||||
apex.amp.register_half_function(torch, 'einsum')
|
|
||||||
|
apex.amp.register_half_function(torch, "einsum")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||||
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_dataset = load_and_cache_examples(
|
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
|
||||||
args, tokenizer, evaluate=False, output_examples=False)
|
|
||||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||||
logger.info(" global_step = %s, average loss = %s",
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
global_step, tr_loss)
|
|
||||||
|
|
||||||
# Save the trained model and the tokenizer
|
# Save the trained model and the tokenizer
|
||||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
@@ -661,18 +765,16 @@ def main():
|
|||||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
# Take care of distributed/parallel training
|
# Take care of distributed/parallel training
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model
|
model_to_save = model.module if hasattr(model, "module") else model
|
||||||
model_to_save.save_pretrained(args.output_dir)
|
model_to_save.save_pretrained(args.output_dir)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(
|
model = model_class.from_pretrained(args.output_dir, force_download=True)
|
||||||
args.output_dir, force_download=True)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
tokenizer = tokenizer_class.from_pretrained(
|
|
||||||
args.output_dir, do_lower_case=args.do_lower_case)
|
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||||
@@ -682,7 +784,10 @@ def main():
|
|||||||
logger.info("Loading checkpoints saved during training for evaluation")
|
logger.info("Loading checkpoints saved during training for evaluation")
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
checkpoints = list(
|
||||||
|
os.path.dirname(c)
|
||||||
|
for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||||
|
)
|
||||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||||
else:
|
else:
|
||||||
logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
|
logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
|
||||||
@@ -692,17 +797,14 @@ def main():
|
|||||||
|
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
# Reload the model
|
# Reload the model
|
||||||
global_step = checkpoint.split(
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
'-')[-1] if len(checkpoints) > 1 else ""
|
model = model_class.from_pretrained(checkpoint, force_download=True)
|
||||||
model = model_class.from_pretrained(
|
|
||||||
checkpoint, force_download=True)
|
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||||
|
|
||||||
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v)
|
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
|
||||||
for k, v in result.items())
|
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
|
||||||
logger.info("Results: {}".format(results))
|
logger.info("Results: {}".format(results))
|
||||||
|
|||||||
@@ -1,7 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow_datasets
|
import tensorflow_datasets
|
||||||
from transformers import BertTokenizer, TFBertForSequenceClassification, BertConfig, glue_convert_examples_to_features, BertForSequenceClassification, glue_processors
|
from transformers import (
|
||||||
|
BertTokenizer,
|
||||||
|
TFBertForSequenceClassification,
|
||||||
|
BertConfig,
|
||||||
|
glue_convert_examples_to_features,
|
||||||
|
BertForSequenceClassification,
|
||||||
|
glue_processors,
|
||||||
|
)
|
||||||
|
|
||||||
# script parameters
|
# script parameters
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
@@ -27,21 +34,21 @@ tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
|
|||||||
|
|
||||||
# Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
|
# Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
|
||||||
config = BertConfig.from_pretrained("bert-base-cased", num_labels=num_labels)
|
config = BertConfig.from_pretrained("bert-base-cased", num_labels=num_labels)
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
|
||||||
model = TFBertForSequenceClassification.from_pretrained('bert-base-cased', config=config)
|
model = TFBertForSequenceClassification.from_pretrained("bert-base-cased", config=config)
|
||||||
|
|
||||||
# Load dataset via TensorFlow Datasets
|
# Load dataset via TensorFlow Datasets
|
||||||
data, info = tensorflow_datasets.load(f'glue/{TFDS_TASK}', with_info=True)
|
data, info = tensorflow_datasets.load(f"glue/{TFDS_TASK}", with_info=True)
|
||||||
train_examples = info.splits['train'].num_examples
|
train_examples = info.splits["train"].num_examples
|
||||||
|
|
||||||
# MNLI expects either validation_matched or validation_mismatched
|
# MNLI expects either validation_matched or validation_mismatched
|
||||||
valid_examples = info.splits['validation'].num_examples
|
valid_examples = info.splits["validation"].num_examples
|
||||||
|
|
||||||
# Prepare dataset for GLUE as a tf.data.Dataset instance
|
# Prepare dataset for GLUE as a tf.data.Dataset instance
|
||||||
train_dataset = glue_convert_examples_to_features(data['train'], tokenizer, 128, TASK)
|
train_dataset = glue_convert_examples_to_features(data["train"], tokenizer, 128, TASK)
|
||||||
|
|
||||||
# MNLI expects either validation_matched or validation_mismatched
|
# MNLI expects either validation_matched or validation_mismatched
|
||||||
valid_dataset = glue_convert_examples_to_features(data['validation'], tokenizer, 128, TASK)
|
valid_dataset = glue_convert_examples_to_features(data["validation"], tokenizer, 128, TASK)
|
||||||
train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
|
train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
|
||||||
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
|
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
|
||||||
|
|
||||||
@@ -49,7 +56,7 @@ valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
|
|||||||
opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
|
opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
|
||||||
if USE_AMP:
|
if USE_AMP:
|
||||||
# loss scaling is currently required when using mixed precision
|
# loss scaling is currently required when using mixed precision
|
||||||
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
|
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")
|
||||||
|
|
||||||
|
|
||||||
if num_labels == 1:
|
if num_labels == 1:
|
||||||
@@ -57,37 +64,42 @@ if num_labels == 1:
|
|||||||
else:
|
else:
|
||||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||||
|
|
||||||
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
|
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||||
model.compile(optimizer=opt, loss=loss, metrics=[metric])
|
model.compile(optimizer=opt, loss=loss, metrics=[metric])
|
||||||
|
|
||||||
# Train and evaluate using tf.keras.Model.fit()
|
# Train and evaluate using tf.keras.Model.fit()
|
||||||
train_steps = train_examples // BATCH_SIZE
|
train_steps = train_examples // BATCH_SIZE
|
||||||
valid_steps = valid_examples // EVAL_BATCH_SIZE
|
valid_steps = valid_examples // EVAL_BATCH_SIZE
|
||||||
|
|
||||||
history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_steps,
|
history = model.fit(
|
||||||
validation_data=valid_dataset, validation_steps=valid_steps)
|
train_dataset,
|
||||||
|
epochs=EPOCHS,
|
||||||
|
steps_per_epoch=train_steps,
|
||||||
|
validation_data=valid_dataset,
|
||||||
|
validation_steps=valid_steps,
|
||||||
|
)
|
||||||
|
|
||||||
# Save TF2 model
|
# Save TF2 model
|
||||||
os.makedirs('./save/', exist_ok=True)
|
os.makedirs("./save/", exist_ok=True)
|
||||||
model.save_pretrained('./save/')
|
model.save_pretrained("./save/")
|
||||||
|
|
||||||
if TASK == "mrpc":
|
if TASK == "mrpc":
|
||||||
# Load the TensorFlow model in PyTorch for inspection
|
# Load the TensorFlow model in PyTorch for inspection
|
||||||
# This is to demo the interoperability between the two frameworks, you don't have to
|
# This is to demo the interoperability between the two frameworks, you don't have to
|
||||||
# do this in real life (you can run the inference on the TF model).
|
# do this in real life (you can run the inference on the TF model).
|
||||||
pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)
|
pytorch_model = BertForSequenceClassification.from_pretrained("./save/", from_tf=True)
|
||||||
|
|
||||||
# Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
|
# Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
|
||||||
sentence_0 = 'This research was consistent with his findings.'
|
sentence_0 = "This research was consistent with his findings."
|
||||||
sentence_1 = 'His findings were compatible with this research.'
|
sentence_1 = "His findings were compatible with this research."
|
||||||
sentence_2 = 'His findings were not compatible with this research.'
|
sentence_2 = "His findings were not compatible with this research."
|
||||||
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')
|
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors="pt")
|
||||||
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')
|
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors="pt")
|
||||||
|
|
||||||
del inputs_1["special_tokens_mask"]
|
del inputs_1["special_tokens_mask"]
|
||||||
del inputs_2["special_tokens_mask"]
|
del inputs_2["special_tokens_mask"]
|
||||||
|
|
||||||
pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
|
pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
|
||||||
pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
|
pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
|
||||||
print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0')
|
print("sentence_1 is", "a paraphrase" if pred_1 else "not a paraphrase", "of sentence_0")
|
||||||
print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0')
|
print("sentence_2 is", "a paraphrase" if pred_2 else "not a paraphrase", "of sentence_0")
|
||||||
|
|||||||
@@ -21,189 +21,156 @@ from absl import app
|
|||||||
|
|
||||||
|
|
||||||
ALL_MODELS = sum(
|
ALL_MODELS = sum(
|
||||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
|
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)), ()
|
||||||
())
|
)
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
"bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
|
"bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
|
||||||
"roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
|
"roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
|
||||||
"distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer)
|
"distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string(
|
||||||
"data_dir", None,
|
"data_dir", None, "The input data dir. Should contain the .conll files (or other data files) " "for the task."
|
||||||
"The input data dir. Should contain the .conll files (or other data files) "
|
)
|
||||||
"for the task.")
|
|
||||||
|
flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||||
|
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string(
|
||||||
"model_type", None,
|
"model_name_or_path",
|
||||||
"Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
None,
|
||||||
|
"Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||||
|
)
|
||||||
|
|
||||||
|
flags.DEFINE_string("output_dir", None, "The output directory where the model checkpoints will be written.")
|
||||||
|
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string(
|
||||||
"model_name_or_path", None,
|
"labels", "", "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."
|
||||||
"Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
)
|
||||||
|
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string("config_name", "", "Pretrained config name or path if not the same as model_name")
|
||||||
"output_dir", None,
|
|
||||||
"The output directory where the model checkpoints will be written.")
|
|
||||||
|
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string("tokenizer_name", "", "Pretrained tokenizer name or path if not the same as model_name")
|
||||||
"labels", "",
|
|
||||||
"Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.")
|
|
||||||
|
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string("cache_dir", "", "Where do you want to store the pre-trained models downloaded from s3")
|
||||||
"config_name", "",
|
|
||||||
"Pretrained config name or path if not the same as model_name")
|
|
||||||
|
|
||||||
flags.DEFINE_string(
|
|
||||||
"tokenizer_name", "",
|
|
||||||
"Pretrained tokenizer name or path if not the same as model_name")
|
|
||||||
|
|
||||||
flags.DEFINE_string(
|
|
||||||
"cache_dir", "",
|
|
||||||
"Where do you want to store the pre-trained models downloaded from s3")
|
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
flags.DEFINE_integer(
|
||||||
"max_seq_length", 128,
|
"max_seq_length",
|
||||||
|
128,
|
||||||
"The maximum total input sentence length after tokenization. "
|
"The maximum total input sentence length after tokenization. "
|
||||||
"Sequences longer than this will be truncated, sequences shorter "
|
"Sequences longer than this will be truncated, sequences shorter "
|
||||||
"will be padded.")
|
"will be padded.",
|
||||||
|
)
|
||||||
|
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string(
|
||||||
"tpu", None,
|
"tpu",
|
||||||
|
None,
|
||||||
"The Cloud TPU to use for training. This should be either the name "
|
"The Cloud TPU to use for training. This should be either the name "
|
||||||
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
|
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
|
||||||
"url.")
|
"url.",
|
||||||
|
)
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
flags.DEFINE_integer("num_tpu_cores", 8, "Total number of TPU cores to use.")
|
||||||
"num_tpu_cores", 8,
|
|
||||||
"Total number of TPU cores to use.")
|
flags.DEFINE_boolean("do_train", False, "Whether to run training.")
|
||||||
|
|
||||||
|
flags.DEFINE_boolean("do_eval", False, "Whether to run eval on the dev set.")
|
||||||
|
|
||||||
|
flags.DEFINE_boolean("do_predict", False, "Whether to run predictions on the test set.")
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
flags.DEFINE_boolean(
|
||||||
"do_train", False,
|
"evaluate_during_training", False, "Whether to run evaluation during training at each logging step."
|
||||||
"Whether to run training.")
|
)
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
flags.DEFINE_boolean("do_lower_case", False, "Set this flag if you are using an uncased model.")
|
||||||
"do_eval", False,
|
|
||||||
"Whether to run eval on the dev set.")
|
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
flags.DEFINE_integer("per_device_train_batch_size", 8, "Batch size per GPU/CPU/TPU for training.")
|
||||||
"do_predict", False,
|
|
||||||
"Whether to run predictions on the test set.")
|
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
flags.DEFINE_integer("per_device_eval_batch_size", 8, "Batch size per GPU/CPU/TPU for evaluation.")
|
||||||
"evaluate_during_training", False,
|
|
||||||
"Whether to run evaluation during training at each logging step.")
|
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
|
||||||
"do_lower_case", False,
|
|
||||||
"Set this flag if you are using an uncased model.")
|
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
flags.DEFINE_integer(
|
||||||
"per_device_train_batch_size", 8,
|
"gradient_accumulation_steps", 1, "Number of updates steps to accumulate before performing a backward/update pass."
|
||||||
"Batch size per GPU/CPU/TPU for training.")
|
)
|
||||||
|
|
||||||
|
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
|
||||||
|
|
||||||
|
flags.DEFINE_float("weight_decay", 0.0, "Weight decay if we apply some.")
|
||||||
|
|
||||||
|
flags.DEFINE_float("adam_epsilon", 1e-8, "Epsilon for Adam optimizer.")
|
||||||
|
|
||||||
|
flags.DEFINE_float("max_grad_norm", 1.0, "Max gradient norm.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer("num_train_epochs", 3, "Total number of training epochs to perform.")
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
flags.DEFINE_integer(
|
||||||
"per_device_eval_batch_size", 8,
|
"max_steps", -1, "If > 0: set total number of training steps to perform. Override num_train_epochs."
|
||||||
"Batch size per GPU/CPU/TPU for evaluation.")
|
)
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
flags.DEFINE_integer("warmup_steps", 0, "Linear warmup over warmup_steps.")
|
||||||
"gradient_accumulation_steps", 1,
|
|
||||||
"Number of updates steps to accumulate before performing a backward/update pass.")
|
|
||||||
|
|
||||||
flags.DEFINE_float(
|
flags.DEFINE_integer("logging_steps", 50, "Log every X updates steps.")
|
||||||
"learning_rate", 5e-5,
|
|
||||||
"The initial learning rate for Adam.")
|
|
||||||
|
|
||||||
flags.DEFINE_float(
|
flags.DEFINE_integer("save_steps", 50, "Save checkpoint every X updates steps.")
|
||||||
"weight_decay", 0.0,
|
|
||||||
"Weight decay if we apply some.")
|
|
||||||
|
|
||||||
flags.DEFINE_float(
|
|
||||||
"adam_epsilon", 1e-8,
|
|
||||||
"Epsilon for Adam optimizer.")
|
|
||||||
|
|
||||||
flags.DEFINE_float(
|
|
||||||
"max_grad_norm", 1.0,
|
|
||||||
"Max gradient norm.")
|
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
|
||||||
"num_train_epochs", 3,
|
|
||||||
"Total number of training epochs to perform.")
|
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
|
||||||
"max_steps", -1,
|
|
||||||
"If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
|
||||||
"warmup_steps", 0,
|
|
||||||
"Linear warmup over warmup_steps.")
|
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
|
||||||
"logging_steps", 50,
|
|
||||||
"Log every X updates steps.")
|
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
|
||||||
"save_steps", 50,
|
|
||||||
"Save checkpoint every X updates steps.")
|
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
flags.DEFINE_boolean(
|
||||||
"eval_all_checkpoints", False,
|
"eval_all_checkpoints",
|
||||||
"Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
False,
|
||||||
|
"Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||||
|
)
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
flags.DEFINE_boolean("no_cuda", False, "Avoid using CUDA when available")
|
||||||
"no_cuda", False,
|
|
||||||
"Avoid using CUDA when available")
|
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
flags.DEFINE_boolean("overwrite_output_dir", False, "Overwrite the content of the output directory")
|
||||||
"overwrite_output_dir", False,
|
|
||||||
"Overwrite the content of the output directory")
|
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
flags.DEFINE_boolean("overwrite_cache", False, "Overwrite the cached training and evaluation sets")
|
||||||
"overwrite_cache", False,
|
|
||||||
"Overwrite the cached training and evaluation sets")
|
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
flags.DEFINE_integer("seed", 42, "random seed for initialization")
|
||||||
"seed", 42,
|
|
||||||
"random seed for initialization")
|
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
flags.DEFINE_boolean("fp16", False, "Whether to use 16-bit (mixed) precision instead of 32-bit")
|
||||||
"fp16", False,
|
|
||||||
"Whether to use 16-bit (mixed) precision instead of 32-bit")
|
|
||||||
|
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string(
|
||||||
"gpus", "0",
|
"gpus",
|
||||||
|
"0",
|
||||||
"Comma separated list of gpus devices. If only one, switch to single "
|
"Comma separated list of gpus devices. If only one, switch to single "
|
||||||
"gpu strategy, if None takes all the gpus available.")
|
"gpu strategy, if None takes all the gpus available.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id):
|
def train(
|
||||||
if args['max_steps'] > 0:
|
args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id
|
||||||
num_train_steps = args['max_steps'] * args['gradient_accumulation_steps']
|
):
|
||||||
args['num_train_epochs'] = 1
|
if args["max_steps"] > 0:
|
||||||
|
num_train_steps = args["max_steps"] * args["gradient_accumulation_steps"]
|
||||||
|
args["num_train_epochs"] = 1
|
||||||
else:
|
else:
|
||||||
num_train_steps = math.ceil(num_train_examples / train_batch_size) // args['gradient_accumulation_steps'] * args['num_train_epochs']
|
num_train_steps = (
|
||||||
|
math.ceil(num_train_examples / train_batch_size)
|
||||||
|
// args["gradient_accumulation_steps"]
|
||||||
|
* args["num_train_epochs"]
|
||||||
|
)
|
||||||
|
|
||||||
writer = tf.summary.create_file_writer("/tmp/mylogs")
|
writer = tf.summary.create_file_writer("/tmp/mylogs")
|
||||||
|
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
|
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
|
||||||
optimizer = create_optimizer(args['learning_rate'], num_train_steps, args['warmup_steps'])
|
optimizer = create_optimizer(args["learning_rate"], num_train_steps, args["warmup_steps"])
|
||||||
|
|
||||||
if args['fp16']:
|
if args["fp16"]:
|
||||||
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, 'dynamic')
|
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")
|
||||||
|
|
||||||
loss_metric = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
|
loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
|
||||||
gradient_accumulator = GradientAccumulator()
|
gradient_accumulator = GradientAccumulator()
|
||||||
|
|
||||||
logging.info("***** Running training *****")
|
logging.info("***** Running training *****")
|
||||||
logging.info(" Num examples = %d", num_train_examples)
|
logging.info(" Num examples = %d", num_train_examples)
|
||||||
logging.info(" Num Epochs = %d", args['num_train_epochs'])
|
logging.info(" Num Epochs = %d", args["num_train_epochs"])
|
||||||
logging.info(" Instantaneous batch size per device = %d", args['per_device_train_batch_size'])
|
logging.info(" Instantaneous batch size per device = %d", args["per_device_train_batch_size"])
|
||||||
logging.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
logging.info(
|
||||||
train_batch_size * args['gradient_accumulation_steps'])
|
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
logging.info(" Gradient Accumulation steps = %d", args['gradient_accumulation_steps'])
|
train_batch_size * args["gradient_accumulation_steps"],
|
||||||
|
)
|
||||||
|
logging.info(" Gradient Accumulation steps = %d", args["gradient_accumulation_steps"])
|
||||||
logging.info(" Total training steps = %d", num_train_steps)
|
logging.info(" Total training steps = %d", num_train_steps)
|
||||||
|
|
||||||
model.summary()
|
model.summary()
|
||||||
@@ -214,26 +181,28 @@ def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, l
|
|||||||
|
|
||||||
for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
|
for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
|
||||||
if gradient is not None:
|
if gradient is not None:
|
||||||
scaled_gradient = gradient / (args['n_device'] * args['gradient_accumulation_steps'])
|
scaled_gradient = gradient / (args["n_device"] * args["gradient_accumulation_steps"])
|
||||||
grads_and_vars.append((scaled_gradient, variable))
|
grads_and_vars.append((scaled_gradient, variable))
|
||||||
else:
|
else:
|
||||||
grads_and_vars.append((gradient, variable))
|
grads_and_vars.append((gradient, variable))
|
||||||
|
|
||||||
optimizer.apply_gradients(grads_and_vars, args['max_grad_norm'])
|
optimizer.apply_gradients(grads_and_vars, args["max_grad_norm"])
|
||||||
gradient_accumulator.reset()
|
gradient_accumulator.reset()
|
||||||
|
|
||||||
@tf.function
|
@tf.function
|
||||||
def train_step(train_features, train_labels):
|
def train_step(train_features, train_labels):
|
||||||
def step_fn(train_features, train_labels):
|
def step_fn(train_features, train_labels):
|
||||||
inputs = {'attention_mask': train_features['input_mask'], 'training': True}
|
inputs = {"attention_mask": train_features["input_mask"], "training": True}
|
||||||
|
|
||||||
if args['model_type'] != "distilbert":
|
if args["model_type"] != "distilbert":
|
||||||
inputs["token_type_ids"] = train_features['segment_ids'] if args['model_type'] in ["bert", "xlnet"] else None
|
inputs["token_type_ids"] = (
|
||||||
|
train_features["segment_ids"] if args["model_type"] in ["bert", "xlnet"] else None
|
||||||
|
)
|
||||||
|
|
||||||
with tf.GradientTape() as tape:
|
with tf.GradientTape() as tape:
|
||||||
logits = model(train_features['input_ids'], **inputs)[0]
|
logits = model(train_features["input_ids"], **inputs)[0]
|
||||||
logits = tf.reshape(logits, (-1, len(labels) + 1))
|
logits = tf.reshape(logits, (-1, len(labels) + 1))
|
||||||
active_loss = tf.reshape(train_features['input_mask'], (-1,))
|
active_loss = tf.reshape(train_features["input_mask"], (-1,))
|
||||||
active_logits = tf.boolean_mask(logits, active_loss)
|
active_logits = tf.boolean_mask(logits, active_loss)
|
||||||
train_labels = tf.reshape(train_labels, (-1,))
|
train_labels = tf.reshape(train_labels, (-1,))
|
||||||
active_labels = tf.boolean_mask(train_labels, active_loss)
|
active_labels = tf.boolean_mask(train_labels, active_loss)
|
||||||
@@ -251,29 +220,35 @@ def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, l
|
|||||||
return mean_loss
|
return mean_loss
|
||||||
|
|
||||||
current_time = datetime.datetime.now()
|
current_time = datetime.datetime.now()
|
||||||
train_iterator = master_bar(range(args['num_train_epochs']))
|
train_iterator = master_bar(range(args["num_train_epochs"]))
|
||||||
global_step = 0
|
global_step = 0
|
||||||
logging_loss = 0.0
|
logging_loss = 0.0
|
||||||
|
|
||||||
for epoch in train_iterator:
|
for epoch in train_iterator:
|
||||||
epoch_iterator = progress_bar(train_dataset, total=num_train_steps, parent=train_iterator, display=args['n_device'] > 1)
|
epoch_iterator = progress_bar(
|
||||||
|
train_dataset, total=num_train_steps, parent=train_iterator, display=args["n_device"] > 1
|
||||||
|
)
|
||||||
step = 1
|
step = 1
|
||||||
|
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
for train_features, train_labels in epoch_iterator:
|
for train_features, train_labels in epoch_iterator:
|
||||||
loss = train_step(train_features, train_labels)
|
loss = train_step(train_features, train_labels)
|
||||||
|
|
||||||
if step % args['gradient_accumulation_steps'] == 0:
|
if step % args["gradient_accumulation_steps"] == 0:
|
||||||
strategy.experimental_run_v2(apply_gradients)
|
strategy.experimental_run_v2(apply_gradients)
|
||||||
|
|
||||||
loss_metric(loss)
|
loss_metric(loss)
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
if args['logging_steps'] > 0 and global_step % args['logging_steps'] == 0:
|
if args["logging_steps"] > 0 and global_step % args["logging_steps"] == 0:
|
||||||
# Log metrics
|
# Log metrics
|
||||||
if args['n_device'] == 1 and args['evaluate_during_training']: # Only evaluate when single GPU otherwise metrics may not average well
|
if (
|
||||||
y_true, y_pred, eval_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev")
|
args["n_device"] == 1 and args["evaluate_during_training"]
|
||||||
|
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||||
|
y_true, y_pred, eval_loss = evaluate(
|
||||||
|
args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
|
||||||
|
)
|
||||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||||
|
|
||||||
logging.info("Eval at step " + str(global_step) + "\n" + report)
|
logging.info("Eval at step " + str(global_step) + "\n" + report)
|
||||||
@@ -294,16 +269,18 @@ def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, l
|
|||||||
|
|
||||||
with writer.as_default():
|
with writer.as_default():
|
||||||
tf.summary.scalar("lr", learning_rate, global_step)
|
tf.summary.scalar("lr", learning_rate, global_step)
|
||||||
tf.summary.scalar("loss", (loss_metric.result() - logging_loss) / args['logging_steps'], global_step)
|
tf.summary.scalar(
|
||||||
|
"loss", (loss_metric.result() - logging_loss) / args["logging_steps"], global_step
|
||||||
|
)
|
||||||
|
|
||||||
logging_loss = loss_metric.result()
|
logging_loss = loss_metric.result()
|
||||||
|
|
||||||
with writer.as_default():
|
with writer.as_default():
|
||||||
tf.summary.scalar("loss", loss_metric.result(), step=step)
|
tf.summary.scalar("loss", loss_metric.result(), step=step)
|
||||||
|
|
||||||
if args['save_steps'] > 0 and global_step % args['save_steps'] == 0:
|
if args["save_steps"] > 0 and global_step % args["save_steps"] == 0:
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
output_dir = os.path.join(args['output_dir'], "checkpoint-{}".format(global_step))
|
output_dir = os.path.join(args["output_dir"], "checkpoint-{}".format(global_step))
|
||||||
|
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
@@ -311,10 +288,10 @@ def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, l
|
|||||||
model.save_pretrained(output_dir)
|
model.save_pretrained(output_dir)
|
||||||
logging.info("Saving model checkpoint to %s", output_dir)
|
logging.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
train_iterator.child.comment = f'loss : {loss_metric.result()}'
|
train_iterator.child.comment = f"loss : {loss_metric.result()}"
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
train_iterator.write(f'loss epoch {epoch + 1}: {loss_metric.result()}')
|
train_iterator.write(f"loss epoch {epoch + 1}: {loss_metric.result()}")
|
||||||
|
|
||||||
loss_metric.reset_states()
|
loss_metric.reset_states()
|
||||||
|
|
||||||
@@ -322,13 +299,15 @@ def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, l
|
|||||||
|
|
||||||
|
|
||||||
def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
|
def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
|
||||||
eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
|
eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
|
||||||
eval_dataset, size = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode)
|
eval_dataset, size = load_and_cache_examples(
|
||||||
|
args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode
|
||||||
|
)
|
||||||
eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
|
eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
|
||||||
preds = None
|
preds = None
|
||||||
num_eval_steps = math.ceil(size / eval_batch_size)
|
num_eval_steps = math.ceil(size / eval_batch_size)
|
||||||
master = master_bar(range(1))
|
master = master_bar(range(1))
|
||||||
eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args['n_device'] > 1)
|
eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args["n_device"] > 1)
|
||||||
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
|
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
@@ -337,15 +316,17 @@ def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode)
|
|||||||
logging.info(" Batch size = %d", eval_batch_size)
|
logging.info(" Batch size = %d", eval_batch_size)
|
||||||
|
|
||||||
for eval_features, eval_labels in eval_iterator:
|
for eval_features, eval_labels in eval_iterator:
|
||||||
inputs = {'attention_mask': eval_features['input_mask'], 'training': False}
|
inputs = {"attention_mask": eval_features["input_mask"], "training": False}
|
||||||
|
|
||||||
if args['model_type'] != "distilbert":
|
if args["model_type"] != "distilbert":
|
||||||
inputs["token_type_ids"] = eval_features['segment_ids'] if args['model_type'] in ["bert", "xlnet"] else None
|
inputs["token_type_ids"] = (
|
||||||
|
eval_features["segment_ids"] if args["model_type"] in ["bert", "xlnet"] else None
|
||||||
|
)
|
||||||
|
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
logits = model(eval_features['input_ids'], **inputs)[0]
|
logits = model(eval_features["input_ids"], **inputs)[0]
|
||||||
tmp_logits = tf.reshape(logits, (-1, len(labels) + 1))
|
tmp_logits = tf.reshape(logits, (-1, len(labels) + 1))
|
||||||
active_loss = tf.reshape(eval_features['input_mask'], (-1,))
|
active_loss = tf.reshape(eval_features["input_mask"], (-1,))
|
||||||
active_logits = tf.boolean_mask(tmp_logits, active_loss)
|
active_logits = tf.boolean_mask(tmp_logits, active_loss)
|
||||||
tmp_eval_labels = tf.reshape(eval_labels, (-1,))
|
tmp_eval_labels = tf.reshape(eval_labels, (-1,))
|
||||||
active_labels = tf.boolean_mask(tmp_eval_labels, active_loss)
|
active_labels = tf.boolean_mask(tmp_eval_labels, active_loss)
|
||||||
@@ -384,11 +365,11 @@ def load_cache(cached_file, max_seq_length):
|
|||||||
def _decode_record(record):
|
def _decode_record(record):
|
||||||
example = tf.io.parse_single_example(record, name_to_features)
|
example = tf.io.parse_single_example(record, name_to_features)
|
||||||
features = {}
|
features = {}
|
||||||
features['input_ids'] = example['input_ids']
|
features["input_ids"] = example["input_ids"]
|
||||||
features['input_mask'] = example['input_mask']
|
features["input_mask"] = example["input_mask"]
|
||||||
features['segment_ids'] = example['segment_ids']
|
features["segment_ids"] = example["segment_ids"]
|
||||||
|
|
||||||
return features, example['label_ids']
|
return features, example["label_ids"]
|
||||||
|
|
||||||
d = tf.data.TFRecordDataset(cached_file)
|
d = tf.data.TFRecordDataset(cached_file)
|
||||||
d = d.map(_decode_record, num_parallel_calls=4)
|
d = d.map(_decode_record, num_parallel_calls=4)
|
||||||
@@ -422,39 +403,46 @@ def save_cache(features, cached_features_file):
|
|||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
|
def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
|
||||||
drop_remainder = True if args['tpu'] or mode == 'train' else False
|
drop_remainder = True if args["tpu"] or mode == "train" else False
|
||||||
|
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
cached_features_file = os.path.join(args['data_dir'], "cached_{}_{}_{}.tf_record".format(mode,
|
cached_features_file = os.path.join(
|
||||||
list(filter(None, args['model_name_or_path'].split("/"))).pop(),
|
args["data_dir"],
|
||||||
str(args['max_seq_length'])))
|
"cached_{}_{}_{}.tf_record".format(
|
||||||
if os.path.exists(cached_features_file) and not args['overwrite_cache']:
|
mode, list(filter(None, args["model_name_or_path"].split("/"))).pop(), str(args["max_seq_length"])
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if os.path.exists(cached_features_file) and not args["overwrite_cache"]:
|
||||||
logging.info("Loading features from cached file %s", cached_features_file)
|
logging.info("Loading features from cached file %s", cached_features_file)
|
||||||
dataset, size = load_cache(cached_features_file, args['max_seq_length'])
|
dataset, size = load_cache(cached_features_file, args["max_seq_length"])
|
||||||
else:
|
else:
|
||||||
logging.info("Creating features from dataset file at %s", args['data_dir'])
|
logging.info("Creating features from dataset file at %s", args["data_dir"])
|
||||||
examples = read_examples_from_file(args['data_dir'], mode)
|
examples = read_examples_from_file(args["data_dir"], mode)
|
||||||
features = convert_examples_to_features(examples, labels, args['max_seq_length'], tokenizer,
|
features = convert_examples_to_features(
|
||||||
cls_token_at_end=bool(args['model_type'] in ["xlnet"]),
|
examples,
|
||||||
|
labels,
|
||||||
|
args["max_seq_length"],
|
||||||
|
tokenizer,
|
||||||
|
cls_token_at_end=bool(args["model_type"] in ["xlnet"]),
|
||||||
# xlnet has a cls token at the end
|
# xlnet has a cls token at the end
|
||||||
cls_token=tokenizer.cls_token,
|
cls_token=tokenizer.cls_token,
|
||||||
cls_token_segment_id=2 if args['model_type'] in ["xlnet"] else 0,
|
cls_token_segment_id=2 if args["model_type"] in ["xlnet"] else 0,
|
||||||
sep_token=tokenizer.sep_token,
|
sep_token=tokenizer.sep_token,
|
||||||
sep_token_extra=bool(args['model_type'] in ["roberta"]),
|
sep_token_extra=bool(args["model_type"] in ["roberta"]),
|
||||||
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||||
pad_on_left=bool(args['model_type'] in ["xlnet"]),
|
pad_on_left=bool(args["model_type"] in ["xlnet"]),
|
||||||
# pad on the left for xlnet
|
# pad on the left for xlnet
|
||||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||||
pad_token_segment_id=4 if args['model_type'] in ["xlnet"] else 0,
|
pad_token_segment_id=4 if args["model_type"] in ["xlnet"] else 0,
|
||||||
pad_token_label_id=pad_token_label_id
|
pad_token_label_id=pad_token_label_id,
|
||||||
)
|
)
|
||||||
logging.info("Saving features into cached file %s", cached_features_file)
|
logging.info("Saving features into cached file %s", cached_features_file)
|
||||||
save_cache(features, cached_features_file)
|
save_cache(features, cached_features_file)
|
||||||
dataset, size = load_cache(cached_features_file, args['max_seq_length'])
|
dataset, size = load_cache(cached_features_file, args["max_seq_length"])
|
||||||
|
|
||||||
if mode == 'train':
|
if mode == "train":
|
||||||
dataset = dataset.repeat()
|
dataset = dataset.repeat()
|
||||||
dataset = dataset.shuffle(buffer_size=8192, seed=args['seed'])
|
dataset = dataset.shuffle(buffer_size=8192, seed=args["seed"])
|
||||||
|
|
||||||
dataset = dataset.batch(batch_size, drop_remainder)
|
dataset = dataset.batch(batch_size, drop_remainder)
|
||||||
dataset = dataset.prefetch(buffer_size=batch_size)
|
dataset = dataset.prefetch(buffer_size=batch_size)
|
||||||
@@ -466,83 +454,117 @@ def main(_):
|
|||||||
logging.set_verbosity(logging.INFO)
|
logging.set_verbosity(logging.INFO)
|
||||||
args = flags.FLAGS.flag_values_dict()
|
args = flags.FLAGS.flag_values_dict()
|
||||||
|
|
||||||
if os.path.exists(args['output_dir']) and os.listdir(
|
if (
|
||||||
args['output_dir']) and args['do_train'] and not args['overwrite_output_dir']:
|
os.path.exists(args["output_dir"])
|
||||||
|
and os.listdir(args["output_dir"])
|
||||||
|
and args["do_train"]
|
||||||
|
and not args["overwrite_output_dir"]
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||||
args['output_dir']))
|
args["output_dir"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if args['fp16']:
|
if args["fp16"]:
|
||||||
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
|
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
|
||||||
|
|
||||||
if args['tpu']:
|
if args["tpu"]:
|
||||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args['tpu'])
|
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args["tpu"])
|
||||||
tf.config.experimental_connect_to_cluster(resolver)
|
tf.config.experimental_connect_to_cluster(resolver)
|
||||||
tf.tpu.experimental.initialize_tpu_system(resolver)
|
tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||||
strategy = tf.distribute.experimental.TPUStrategy(resolver)
|
strategy = tf.distribute.experimental.TPUStrategy(resolver)
|
||||||
args['n_device'] = args['num_tpu_cores']
|
args["n_device"] = args["num_tpu_cores"]
|
||||||
elif len(args['gpus'].split(',')) > 1:
|
elif len(args["gpus"].split(",")) > 1:
|
||||||
args['n_device'] = len([f"/gpu:{gpu}" for gpu in args['gpus'].split(',')])
|
args["n_device"] = len([f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
|
||||||
strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args['gpus'].split(',')])
|
strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
|
||||||
elif args['no_cuda']:
|
elif args["no_cuda"]:
|
||||||
args['n_device'] = 1
|
args["n_device"] = 1
|
||||||
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
|
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
|
||||||
else:
|
else:
|
||||||
args['n_device'] = len(args['gpus'].split(','))
|
args["n_device"] = len(args["gpus"].split(","))
|
||||||
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args['gpus'].split(',')[0])
|
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args["gpus"].split(",")[0])
|
||||||
|
|
||||||
logging.warning("n_device: %s, distributed training: %s, 16-bits training: %s",
|
logging.warning(
|
||||||
args['n_device'], bool(args['n_device'] > 1), args['fp16'])
|
"n_device: %s, distributed training: %s, 16-bits training: %s",
|
||||||
|
args["n_device"],
|
||||||
|
bool(args["n_device"] > 1),
|
||||||
|
args["fp16"],
|
||||||
|
)
|
||||||
|
|
||||||
labels = get_labels(args['labels'])
|
labels = get_labels(args["labels"])
|
||||||
num_labels = len(labels) + 1
|
num_labels = len(labels) + 1
|
||||||
pad_token_label_id = 0
|
pad_token_label_id = 0
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args['model_type']]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
|
||||||
config = config_class.from_pretrained(args['config_name'] if args['config_name'] else args['model_name_or_path'],
|
config = config_class.from_pretrained(
|
||||||
|
args["config_name"] if args["config_name"] else args["model_name_or_path"],
|
||||||
num_labels=num_labels,
|
num_labels=num_labels,
|
||||||
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
|
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
||||||
|
)
|
||||||
|
|
||||||
logging.info("Training/evaluation parameters %s", args)
|
logging.info("Training/evaluation parameters %s", args)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
if args['do_train']:
|
if args["do_train"]:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args['tokenizer_name'] if args['tokenizer_name'] else args['model_name_or_path'],
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
do_lower_case=args['do_lower_case'],
|
args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
|
||||||
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
|
do_lower_case=args["do_lower_case"],
|
||||||
|
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
||||||
|
)
|
||||||
|
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
model = model_class.from_pretrained(args['model_name_or_path'],
|
model = model_class.from_pretrained(
|
||||||
from_pt=bool(".bin" in args['model_name_or_path']),
|
args["model_name_or_path"],
|
||||||
|
from_pt=bool(".bin" in args["model_name_or_path"]),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
|
cache_dir=args["cache_dir"] if args["cache_dir"] else None,
|
||||||
|
)
|
||||||
model.layers[-1].activation = tf.keras.activations.softmax
|
model.layers[-1].activation = tf.keras.activations.softmax
|
||||||
|
|
||||||
train_batch_size = args['per_device_train_batch_size'] * args['n_device']
|
train_batch_size = args["per_device_train_batch_size"] * args["n_device"]
|
||||||
train_dataset, num_train_examples = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train")
|
train_dataset, num_train_examples = load_and_cache_examples(
|
||||||
|
args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train"
|
||||||
|
)
|
||||||
train_dataset = strategy.experimental_distribute_dataset(train_dataset)
|
train_dataset = strategy.experimental_distribute_dataset(train_dataset)
|
||||||
train(args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id)
|
train(
|
||||||
|
args,
|
||||||
|
strategy,
|
||||||
|
train_dataset,
|
||||||
|
tokenizer,
|
||||||
|
model,
|
||||||
|
num_train_examples,
|
||||||
|
labels,
|
||||||
|
train_batch_size,
|
||||||
|
pad_token_label_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not os.path.exists(args['output_dir']):
|
if not os.path.exists(args["output_dir"]):
|
||||||
os.makedirs(args['output_dir'])
|
os.makedirs(args["output_dir"])
|
||||||
|
|
||||||
logging.info("Saving model to %s", args['output_dir'])
|
logging.info("Saving model to %s", args["output_dir"])
|
||||||
|
|
||||||
model.save_pretrained(args['output_dir'])
|
model.save_pretrained(args["output_dir"])
|
||||||
tokenizer.save_pretrained(args['output_dir'])
|
tokenizer.save_pretrained(args["output_dir"])
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
if args['do_eval']:
|
if args["do_eval"]:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
|
tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||||
checkpoints = []
|
checkpoints = []
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
if args['eval_all_checkpoints']:
|
if args["eval_all_checkpoints"]:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args['output_dir'] + "/**/" + TF2_WEIGHTS_NAME, recursive=True), key=lambda f: int(''.join(filter(str.isdigit, f)) or -1)))
|
checkpoints = list(
|
||||||
|
os.path.dirname(c)
|
||||||
|
for c in sorted(
|
||||||
|
glob.glob(args["output_dir"] + "/**/" + TF2_WEIGHTS_NAME, recursive=True),
|
||||||
|
key=lambda f: int("".join(filter(str.isdigit, f)) or -1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
logging.info("Evaluate the following checkpoints: %s", checkpoints)
|
logging.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
|
|
||||||
if len(checkpoints) == 0:
|
if len(checkpoints) == 0:
|
||||||
checkpoints.append(args['output_dir'])
|
checkpoints.append(args["output_dir"])
|
||||||
|
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
|
global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
|
||||||
@@ -550,13 +572,15 @@ def main(_):
|
|||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = model_class.from_pretrained(checkpoint)
|
||||||
|
|
||||||
y_true, y_pred, eval_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev")
|
y_true, y_pred, eval_loss = evaluate(
|
||||||
|
args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
|
||||||
|
)
|
||||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||||
|
|
||||||
if global_step:
|
if global_step:
|
||||||
results.append({global_step + "_report": report, global_step + "_loss": eval_loss})
|
results.append({global_step + "_report": report, global_step + "_loss": eval_loss})
|
||||||
|
|
||||||
output_eval_file = os.path.join(args['output_dir'], "eval_results.txt")
|
output_eval_file = os.path.join(args["output_dir"], "eval_results.txt")
|
||||||
|
|
||||||
with tf.io.gfile.GFile(output_eval_file, "w") as writer:
|
with tf.io.gfile.GFile(output_eval_file, "w") as writer:
|
||||||
for res in results:
|
for res in results:
|
||||||
@@ -572,14 +596,16 @@ def main(_):
|
|||||||
writer.write(report)
|
writer.write(report)
|
||||||
writer.write("\n")
|
writer.write("\n")
|
||||||
|
|
||||||
if args['do_predict']:
|
if args["do_predict"]:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
|
tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
|
||||||
model = model_class.from_pretrained(args['output_dir'])
|
model = model_class.from_pretrained(args["output_dir"])
|
||||||
eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
|
eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
|
||||||
predict_dataset, _ = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test")
|
predict_dataset, _ = load_and_cache_examples(
|
||||||
|
args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
|
||||||
|
)
|
||||||
y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
|
y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
|
||||||
output_test_results_file = os.path.join(args['output_dir'], "test_results.txt")
|
output_test_results_file = os.path.join(args["output_dir"], "test_results.txt")
|
||||||
output_test_predictions_file = os.path.join(args['output_dir'], "test_predictions.txt")
|
output_test_predictions_file = os.path.join(args["output_dir"], "test_predictions.txt")
|
||||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||||
|
|
||||||
with tf.io.gfile.GFile(output_test_results_file, "w") as writer:
|
with tf.io.gfile.GFile(output_test_results_file, "w") as writer:
|
||||||
@@ -591,7 +617,7 @@ def main(_):
|
|||||||
writer.write("\n\nloss = " + str(pred_loss))
|
writer.write("\n\nloss = " + str(pred_loss))
|
||||||
|
|
||||||
with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
|
with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
|
||||||
with tf.io.gfile.GFile(os.path.join(args['data_dir'], "test.txt"), "r") as f:
|
with tf.io.gfile.GFile(os.path.join(args["data_dir"], "test.txt"), "r") as f:
|
||||||
example_id = 0
|
example_id = 0
|
||||||
|
|
||||||
for line in f:
|
for line in f:
|
||||||
|
|||||||
@@ -26,8 +26,7 @@ import random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
TensorDataset)
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -37,10 +36,18 @@ except:
|
|||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (WEIGHTS_NAME,
|
from transformers import (
|
||||||
BertConfig, BertForSequenceClassification, BertTokenizer,
|
WEIGHTS_NAME,
|
||||||
XLMConfig, XLMForSequenceClassification, XLMTokenizer,
|
BertConfig,
|
||||||
DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
|
BertForSequenceClassification,
|
||||||
|
BertTokenizer,
|
||||||
|
XLMConfig,
|
||||||
|
XLMForSequenceClassification,
|
||||||
|
XLMTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertForSequenceClassification,
|
||||||
|
DistilBertTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||||
|
|
||||||
@@ -52,12 +59,14 @@ from transformers import glue_convert_examples_to_features as convert_examples_t
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, DistilBertConfig, XLMConfig)), ())
|
ALL_MODELS = sum(
|
||||||
|
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, DistilBertConfig, XLMConfig)), ()
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
|
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||||
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
"xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||||
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
|
"distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -85,19 +94,26 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
|
|
||||||
# Prepare optimizer and schedule (linear warmup and decay)
|
# Prepare optimizer and schedule (linear warmup and decay)
|
||||||
no_decay = ['bias', 'LayerNorm.weight']
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
{
|
||||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
},
|
||||||
|
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||||
]
|
]
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||||
|
)
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
|
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||||
|
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||||
|
):
|
||||||
# Load in optimizer and scheduler states
|
# Load in optimizer and scheduler states
|
||||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
|
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
|
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
@@ -112,17 +128,21 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
output_device=args.local_rank,
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
find_unused_parameters=True)
|
)
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_dataset))
|
logger.info(" Num examples = %d", len(train_dataset))
|
||||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
logger.info(
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
|
args.train_batch_size
|
||||||
|
* args.gradient_accumulation_steps
|
||||||
|
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||||
|
)
|
||||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
@@ -132,7 +152,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
if os.path.exists(args.model_name_or_path):
|
if os.path.exists(args.model_name_or_path):
|
||||||
# set global_step to gobal_step of last saved checkpoint from model path
|
# set global_step to gobal_step of last saved checkpoint from model path
|
||||||
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
|
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
|
|
||||||
@@ -143,7 +163,9 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
tr_loss, logging_loss = 0.0, 0.0
|
tr_loss, logging_loss = 0.0, 0.0
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
train_iterator = trange(
|
||||||
|
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||||
|
)
|
||||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||||
for _ in train_iterator:
|
for _ in train_iterator:
|
||||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||||
@@ -155,11 +177,11 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||||
'attention_mask': batch[1],
|
if args.model_type != "distilbert":
|
||||||
'labels': batch[3]}
|
inputs["token_type_ids"] = (
|
||||||
if args.model_type != 'distilbert':
|
batch[2] if args.model_type in ["bert"] else None
|
||||||
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert'] else None # XLM and DistilBERT don't use segment_ids
|
) # XLM and DistilBERT don't use segment_ids
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||||
|
|
||||||
@@ -188,28 +210,32 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
# Log metrics
|
# Log metrics
|
||||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
if (
|
||||||
|
args.local_rank == -1 and args.evaluate_during_training
|
||||||
|
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||||
results = evaluate(args, model, tokenizer)
|
results = evaluate(args, model, tokenizer)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||||
logging_loss = tr_loss
|
logging_loss = tr_loss
|
||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(output_dir)
|
model_to_save.save_pretrained(output_dir)
|
||||||
tokenizer.save_pretrained(output_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
|
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
|
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
@@ -258,11 +284,11 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||||
'attention_mask': batch[1],
|
if args.model_type != "distilbert":
|
||||||
'labels': batch[3]}
|
inputs["token_type_ids"] = (
|
||||||
if args.model_type != 'distilbert':
|
batch[2] if args.model_type in ["bert"] else None
|
||||||
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert'] else None # XLM and DistilBERT don't use segment_ids
|
) # XLM and DistilBERT don't use segment_ids
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
tmp_eval_loss, logits = outputs[:2]
|
tmp_eval_loss, logits = outputs[:2]
|
||||||
|
|
||||||
@@ -270,16 +296,16 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
nb_eval_steps += 1
|
nb_eval_steps += 1
|
||||||
if preds is None:
|
if preds is None:
|
||||||
preds = logits.detach().cpu().numpy()
|
preds = logits.detach().cpu().numpy()
|
||||||
out_label_ids = inputs['labels'].detach().cpu().numpy()
|
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||||
else:
|
else:
|
||||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||||
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
|
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||||
|
|
||||||
eval_loss = eval_loss / nb_eval_steps
|
eval_loss = eval_loss / nb_eval_steps
|
||||||
if args.output_mode == "classification":
|
if args.output_mode == "classification":
|
||||||
preds = np.argmax(preds, axis=1)
|
preds = np.argmax(preds, axis=1)
|
||||||
else:
|
else:
|
||||||
raise ValueError('No other `output_mode` for XNLI.')
|
raise ValueError("No other `output_mode` for XNLI.")
|
||||||
result = compute_metrics(eval_task, preds, out_label_ids)
|
result = compute_metrics(eval_task, preds, out_label_ids)
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
|
||||||
@@ -300,20 +326,27 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
|||||||
processor = processors[task](language=args.language, train_language=args.train_language)
|
processor = processors[task](language=args.language, train_language=args.train_language)
|
||||||
output_mode = output_modes[task]
|
output_mode = output_modes[task]
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}_{}'.format(
|
cached_features_file = os.path.join(
|
||||||
'test' if evaluate else 'train',
|
args.data_dir,
|
||||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
"cached_{}_{}_{}_{}_{}".format(
|
||||||
|
"test" if evaluate else "train",
|
||||||
|
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||||
str(args.max_seq_length),
|
str(args.max_seq_length),
|
||||||
str(task),
|
str(task),
|
||||||
str(args.train_language if (not evaluate and args.train_language is not None) else args.language)))
|
str(args.train_language if (not evaluate and args.train_language is not None) else args.language),
|
||||||
|
),
|
||||||
|
)
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
features = torch.load(cached_features_file)
|
features = torch.load(cached_features_file)
|
||||||
else:
|
else:
|
||||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||||
label_list = processor.get_labels()
|
label_list = processor.get_labels()
|
||||||
examples = processor.get_test_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
examples = (
|
||||||
features = convert_examples_to_features(examples,
|
processor.get_test_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||||
|
)
|
||||||
|
features = convert_examples_to_features(
|
||||||
|
examples,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
label_list=label_list,
|
label_list=label_list,
|
||||||
max_length=args.max_seq_length,
|
max_length=args.max_seq_length,
|
||||||
@@ -336,7 +369,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
|||||||
if output_mode == "classification":
|
if output_mode == "classification":
|
||||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||||
else:
|
else:
|
||||||
raise ValueError('No other `output_mode` for XNLI.')
|
raise ValueError("No other `output_mode` for XNLI.")
|
||||||
|
|
||||||
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
||||||
return dataset
|
return dataset
|
||||||
@@ -346,92 +379,152 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
"--data_dir",
|
||||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
default=None,
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
type=str,
|
||||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
required=True,
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||||
parser.add_argument("--language", default=None, type=str, required=True,
|
)
|
||||||
help="Evaluation language. Also train language if `train_language` is set to None.")
|
parser.add_argument(
|
||||||
parser.add_argument("--train_language", default=None, type=str,
|
"--model_type",
|
||||||
help="Train language if is different of the evaluation language.")
|
default=None,
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
type=str,
|
||||||
help="The output directory where the model predictions and checkpoints will be written.")
|
required=True,
|
||||||
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--language",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Evaluation language. Also train language if `train_language` is set to None.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_language", default=None, type=str, help="Train language if is different of the evaluation language."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
|
)
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument(
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
)
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
parser.add_argument(
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
"--tokenizer_name",
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
default="",
|
||||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
type=str,
|
||||||
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_dir",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_seq_length",
|
||||||
|
default=128,
|
||||||
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded.")
|
"than this will be truncated, sequences shorter will be padded.",
|
||||||
parser.add_argument("--do_train", action='store_true',
|
)
|
||||||
help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action='store_true',
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the test set.")
|
||||||
help="Whether to run eval on the test set.")
|
parser.add_argument(
|
||||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||||
help="Rul evaluation during training at each logging step.")
|
)
|
||||||
parser.add_argument("--do_lower_case", action='store_true',
|
parser.add_argument(
|
||||||
help="Set this flag if you are using an uncased model.")
|
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||||
help="Batch size per GPU/CPU for training.")
|
parser.add_argument(
|
||||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||||
help="Batch size per GPU/CPU for evaluation.")
|
)
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
parser.add_argument(
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
"--gradient_accumulation_steps",
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
type=int,
|
||||||
help="The initial learning rate for Adam.")
|
default=1,
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
help="Weight deay if we apply some.")
|
)
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
help="Epsilon for Adam optimizer.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
help="Max gradient norm.")
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument(
|
||||||
help="Total number of training epochs to perform.")
|
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
)
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
parser.add_argument(
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
"--max_steps",
|
||||||
help="Linear warmup over warmup_steps.")
|
default=-1,
|
||||||
|
type=int,
|
||||||
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
|
||||||
parser.add_argument('--logging_steps', type=int, default=50,
|
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||||
help="Log every X updates steps.")
|
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||||
parser.add_argument('--save_steps', type=int, default=50,
|
parser.add_argument(
|
||||||
help="Save checkpoint every X updates steps.")
|
"--eval_all_checkpoints",
|
||||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
action="store_true",
|
||||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
)
|
||||||
help="Avoid using CUDA when available")
|
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
parser.add_argument(
|
||||||
help="Overwrite the content of the output directory")
|
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||||
parser.add_argument('--overwrite_cache', action='store_true',
|
)
|
||||||
help="Overwrite the cached training and evaluation sets")
|
parser.add_argument(
|
||||||
parser.add_argument('--seed', type=int, default=42,
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
help="random seed for initialization")
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
|
||||||
parser.add_argument('--fp16', action='store_true',
|
parser.add_argument(
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
"--fp16",
|
||||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
action="store_true",
|
||||||
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16_opt_level",
|
||||||
|
type=str,
|
||||||
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html")
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
parser.add_argument("--local_rank", type=int, default=-1,
|
)
|
||||||
help="For distributed training: local_rank")
|
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
if (
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
os.path.exists(args.output_dir)
|
||||||
|
and os.listdir(args.output_dir)
|
||||||
|
and args.do_train
|
||||||
|
and not args.overwrite_output_dir
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||||
|
args.output_dir
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Setup distant debugging if needed
|
# Setup distant debugging if needed
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
@@ -443,22 +536,30 @@ def main():
|
|||||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||||
torch.cuda.set_device(args.local_rank)
|
torch.cuda.set_device(args.local_rank)
|
||||||
device = torch.device("cuda", args.local_rank)
|
device = torch.device("cuda", args.local_rank)
|
||||||
torch.distributed.init_process_group(backend='nccl')
|
torch.distributed.init_process_group(backend="nccl")
|
||||||
args.n_gpu = 1
|
args.n_gpu = 1
|
||||||
args.device = device
|
args.device = device
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||||
|
args.local_rank,
|
||||||
|
device,
|
||||||
|
args.n_gpu,
|
||||||
|
bool(args.local_rank != -1),
|
||||||
|
args.fp16,
|
||||||
|
)
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
|
|
||||||
# Prepare XNLI task
|
# Prepare XNLI task
|
||||||
args.task_name = 'xnli'
|
args.task_name = "xnli"
|
||||||
if args.task_name not in processors:
|
if args.task_name not in processors:
|
||||||
raise ValueError("Task not found: %s" % (args.task_name))
|
raise ValueError("Task not found: %s" % (args.task_name))
|
||||||
processor = processors[args.task_name](language=args.language, train_language=args.train_language)
|
processor = processors[args.task_name](language=args.language, train_language=args.train_language)
|
||||||
@@ -472,17 +573,23 @@ def main():
|
|||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
config = config_class.from_pretrained(
|
||||||
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
num_labels=num_labels,
|
num_labels=num_labels,
|
||||||
finetuning_task=args.task_name,
|
finetuning_task=args.task_name,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
)
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
do_lower_case=args.do_lower_case,
|
do_lower_case=args.do_lower_case,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
model = model_class.from_pretrained(args.model_name_or_path,
|
)
|
||||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
model = model_class.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
@@ -491,14 +598,12 @@ def main():
|
|||||||
|
|
||||||
logger.info("Training/evaluation parameters %s", args)
|
logger.info("Training/evaluation parameters %s", args)
|
||||||
|
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
|
|
||||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
@@ -508,36 +613,39 @@ def main():
|
|||||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(args.output_dir)
|
model_to_save.save_pretrained(args.output_dir)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = model_class.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
checkpoints = list(
|
||||||
|
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||||
|
)
|
||||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||||
|
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = model_class.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -34,12 +34,30 @@ logging.basicConfig(level=logging.INFO)
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_TEXT = 'Hello world! cécé herlolip'
|
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
||||||
|
|
||||||
|
|
||||||
BertAbsConfig = namedtuple(
|
BertAbsConfig = namedtuple(
|
||||||
"BertAbsConfig",
|
"BertAbsConfig",
|
||||||
["temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
|
[
|
||||||
|
"temp_dir",
|
||||||
|
"large",
|
||||||
|
"use_bert_emb",
|
||||||
|
"finetune_bert",
|
||||||
|
"encoder",
|
||||||
|
"share_emb",
|
||||||
|
"max_pos",
|
||||||
|
"enc_layers",
|
||||||
|
"enc_hidden_size",
|
||||||
|
"enc_heads",
|
||||||
|
"enc_ff_size",
|
||||||
|
"enc_dropout",
|
||||||
|
"dec_layers",
|
||||||
|
"dec_hidden_size",
|
||||||
|
"dec_heads",
|
||||||
|
"dec_ff_size",
|
||||||
|
"dec_dropout",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -119,7 +137,9 @@ def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
|||||||
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
|
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
|
||||||
output_original_generator = original.generator(output_original_model)
|
output_original_generator = original.generator(output_original_model)
|
||||||
|
|
||||||
output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0]
|
output_converted_model = new_model(
|
||||||
|
encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask
|
||||||
|
)[0]
|
||||||
output_converted_generator = new_model.generator(output_converted_model)
|
output_converted_generator = new_model.generator(output_converted_model)
|
||||||
|
|
||||||
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
|
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
|
||||||
@@ -136,28 +156,21 @@ def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
|||||||
# The model has been saved with torch.save(model) and this is bound to the exact
|
# The model has been saved with torch.save(model) and this is bound to the exact
|
||||||
# directory structure. We save the state_dict instead.
|
# directory structure. We save the state_dict instead.
|
||||||
logging.info("saving the model's state dictionary")
|
logging.info("saving the model's state dictionary")
|
||||||
torch.save(new_model.state_dict(), "bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin")
|
torch.save(
|
||||||
|
new_model.state_dict(), "bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bertabs_checkpoint_path",
|
"--bertabs_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump.",
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path the official PyTorch dump.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pytorch_dump_folder_path",
|
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model.",
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the output PyTorch model.",
|
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
convert_bertabs_checkpoints(
|
convert_bertabs_checkpoints(
|
||||||
args.bertabs_checkpoint_path,
|
args.bertabs_checkpoint_path, args.pytorch_dump_folder_path,
|
||||||
args.pytorch_dump_folder_path,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,40 +56,22 @@ class BertAbs(BertAbsPreTrainedModel):
|
|||||||
load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
|
load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
|
||||||
if load_bert_pretrained_extractive:
|
if load_bert_pretrained_extractive:
|
||||||
self.bert.model.load_state_dict(
|
self.bert.model.load_state_dict(
|
||||||
dict(
|
dict([(n[11:], p) for n, p in bert_extractive_checkpoint.items() if n.startswith("bert.model")]),
|
||||||
[
|
|
||||||
(n[11:], p)
|
|
||||||
for n, p in bert_extractive_checkpoint.items()
|
|
||||||
if n.startswith("bert.model")
|
|
||||||
]
|
|
||||||
),
|
|
||||||
strict=True,
|
strict=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.vocab_size = self.bert.model.config.vocab_size
|
self.vocab_size = self.bert.model.config.vocab_size
|
||||||
|
|
||||||
if args.max_pos > 512:
|
if args.max_pos > 512:
|
||||||
my_pos_embeddings = nn.Embedding(
|
my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
|
||||||
args.max_pos, self.bert.model.config.hidden_size
|
my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
|
||||||
)
|
my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
|
||||||
my_pos_embeddings.weight.data[
|
|
||||||
:512
|
|
||||||
] = self.bert.model.embeddings.position_embeddings.weight.data
|
|
||||||
my_pos_embeddings.weight.data[
|
|
||||||
512:
|
|
||||||
] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
|
|
||||||
None, :
|
None, :
|
||||||
].repeat(
|
].repeat(args.max_pos - 512, 1)
|
||||||
args.max_pos - 512, 1
|
|
||||||
)
|
|
||||||
self.bert.model.embeddings.position_embeddings = my_pos_embeddings
|
self.bert.model.embeddings.position_embeddings = my_pos_embeddings
|
||||||
tgt_embeddings = nn.Embedding(
|
tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
|
||||||
self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
|
|
||||||
)
|
|
||||||
|
|
||||||
tgt_embeddings.weight = copy.deepcopy(
|
tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
|
||||||
self.bert.model.embeddings.word_embeddings.weight
|
|
||||||
)
|
|
||||||
|
|
||||||
self.decoder = TransformerDecoder(
|
self.decoder = TransformerDecoder(
|
||||||
self.args.dec_layers,
|
self.args.dec_layers,
|
||||||
@@ -102,9 +84,7 @@ class BertAbs(BertAbsPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
gen_func = nn.LogSoftmax(dim=-1)
|
gen_func = nn.LogSoftmax(dim=-1)
|
||||||
self.generator = nn.Sequential(
|
self.generator = nn.Sequential(nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func)
|
||||||
nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func
|
|
||||||
)
|
|
||||||
self.generator[0].weight = self.decoder.embeddings.weight
|
self.generator[0].weight = self.decoder.embeddings.weight
|
||||||
|
|
||||||
load_from_checkpoints = False if checkpoint is None else True
|
load_from_checkpoints = False if checkpoint is None else True
|
||||||
@@ -127,25 +107,14 @@ class BertAbs(BertAbsPreTrainedModel):
|
|||||||
p.data.zero_()
|
p.data.zero_()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self, encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask,
|
||||||
encoder_input_ids,
|
|
||||||
decoder_input_ids,
|
|
||||||
token_type_ids,
|
|
||||||
encoder_attention_mask,
|
|
||||||
decoder_attention_mask,
|
|
||||||
):
|
):
|
||||||
encoder_output = self.bert(
|
encoder_output = self.bert(
|
||||||
input_ids=encoder_input_ids,
|
input_ids=encoder_input_ids, token_type_ids=token_type_ids, attention_mask=encoder_attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
attention_mask=encoder_attention_mask,
|
|
||||||
)
|
)
|
||||||
encoder_hidden_states = encoder_output[0]
|
encoder_hidden_states = encoder_output[0]
|
||||||
dec_state = self.decoder.init_decoder_state(
|
dec_state = self.decoder.init_decoder_state(encoder_input_ids, encoder_hidden_states)
|
||||||
encoder_input_ids, encoder_hidden_states
|
decoder_outputs, _ = self.decoder(decoder_input_ids[:, :-1], encoder_hidden_states, dec_state)
|
||||||
)
|
|
||||||
decoder_outputs, _ = self.decoder(
|
|
||||||
decoder_input_ids[:, :-1], encoder_hidden_states, dec_state
|
|
||||||
)
|
|
||||||
return decoder_outputs
|
return decoder_outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -162,10 +131,7 @@ class Bert(nn.Module):
|
|||||||
self.eval()
|
self.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
encoder_outputs, _ = self.model(
|
encoder_outputs, _ = self.model(
|
||||||
input_ids,
|
input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, **kwargs
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
**kwargs
|
|
||||||
)
|
)
|
||||||
return encoder_outputs
|
return encoder_outputs
|
||||||
|
|
||||||
@@ -196,10 +162,7 @@ class TransformerDecoder(nn.Module):
|
|||||||
|
|
||||||
# Build TransformerDecoder.
|
# Build TransformerDecoder.
|
||||||
self.transformer_layers = nn.ModuleList(
|
self.transformer_layers = nn.ModuleList(
|
||||||
[
|
[TransformerDecoderLayer(d_model, heads, d_ff, dropout) for _ in range(num_layers)]
|
||||||
TransformerDecoderLayer(d_model, heads, d_ff, dropout)
|
|
||||||
for _ in range(num_layers)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
||||||
@@ -236,20 +199,14 @@ class TransformerDecoder(nn.Module):
|
|||||||
# Decoder padding mask
|
# Decoder padding mask
|
||||||
tgt_words = tgt
|
tgt_words = tgt
|
||||||
tgt_batch, tgt_len = tgt_words.size()
|
tgt_batch, tgt_len = tgt_words.size()
|
||||||
tgt_pad_mask = (
|
tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)
|
||||||
tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Encoder padding mask
|
# Encoder padding mask
|
||||||
if memory_mask is not None:
|
if memory_mask is not None:
|
||||||
src_len = memory_mask.size(-1)
|
src_len = memory_mask.size(-1)
|
||||||
src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len)
|
src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len)
|
||||||
else:
|
else:
|
||||||
src_pad_mask = (
|
src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1).expand(src_batch, tgt_len, src_len)
|
||||||
src_words.data.eq(padding_idx)
|
|
||||||
.unsqueeze(1)
|
|
||||||
.expand(src_batch, tgt_len, src_len)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pass through the embeddings
|
# Pass through the embeddings
|
||||||
emb = self.embeddings(input_ids)
|
emb = self.embeddings(input_ids)
|
||||||
@@ -271,9 +228,7 @@ class TransformerDecoder(nn.Module):
|
|||||||
src_pad_mask,
|
src_pad_mask,
|
||||||
tgt_pad_mask,
|
tgt_pad_mask,
|
||||||
previous_input=prev_layer_input,
|
previous_input=prev_layer_input,
|
||||||
layer_cache=state.cache["layer_{}".format(i)]
|
layer_cache=state.cache["layer_{}".format(i)] if state.cache is not None else None,
|
||||||
if state.cache is not None
|
|
||||||
else None,
|
|
||||||
step=step,
|
step=step,
|
||||||
)
|
)
|
||||||
if state.cache is None:
|
if state.cache is None:
|
||||||
@@ -303,9 +258,7 @@ class PositionalEncoding(nn.Module):
|
|||||||
def __init__(self, dropout, dim, max_len=5000):
|
def __init__(self, dropout, dim, max_len=5000):
|
||||||
pe = torch.zeros(max_len, dim)
|
pe = torch.zeros(max_len, dim)
|
||||||
position = torch.arange(0, max_len).unsqueeze(1)
|
position = torch.arange(0, max_len).unsqueeze(1)
|
||||||
div_term = torch.exp(
|
div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)))
|
||||||
(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
|
|
||||||
)
|
|
||||||
pe[:, 0::2] = torch.sin(position.float() * div_term)
|
pe[:, 0::2] = torch.sin(position.float() * div_term)
|
||||||
pe[:, 1::2] = torch.cos(position.float() * div_term)
|
pe[:, 1::2] = torch.cos(position.float() * div_term)
|
||||||
pe = pe.unsqueeze(0)
|
pe = pe.unsqueeze(0)
|
||||||
@@ -356,14 +309,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
self.register_buffer("mask", mask)
|
self.register_buffer("mask", mask)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, previous_input=None, layer_cache=None, step=None,
|
||||||
inputs,
|
|
||||||
memory_bank,
|
|
||||||
src_pad_mask,
|
|
||||||
tgt_pad_mask,
|
|
||||||
previous_input=None,
|
|
||||||
layer_cache=None,
|
|
||||||
step=None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -380,34 +326,20 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
* all_input `[batch_size x current_step x model_dim]`
|
* all_input `[batch_size x current_step x model_dim]`
|
||||||
|
|
||||||
"""
|
"""
|
||||||
dec_mask = torch.gt(
|
dec_mask = torch.gt(tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0)
|
||||||
tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0
|
|
||||||
)
|
|
||||||
input_norm = self.layer_norm_1(inputs)
|
input_norm = self.layer_norm_1(inputs)
|
||||||
all_input = input_norm
|
all_input = input_norm
|
||||||
if previous_input is not None:
|
if previous_input is not None:
|
||||||
all_input = torch.cat((previous_input, input_norm), dim=1)
|
all_input = torch.cat((previous_input, input_norm), dim=1)
|
||||||
dec_mask = None
|
dec_mask = None
|
||||||
|
|
||||||
query = self.self_attn(
|
query = self.self_attn(all_input, all_input, input_norm, mask=dec_mask, layer_cache=layer_cache, type="self",)
|
||||||
all_input,
|
|
||||||
all_input,
|
|
||||||
input_norm,
|
|
||||||
mask=dec_mask,
|
|
||||||
layer_cache=layer_cache,
|
|
||||||
type="self",
|
|
||||||
)
|
|
||||||
|
|
||||||
query = self.drop(query) + inputs
|
query = self.drop(query) + inputs
|
||||||
|
|
||||||
query_norm = self.layer_norm_2(query)
|
query_norm = self.layer_norm_2(query)
|
||||||
mid = self.context_attn(
|
mid = self.context_attn(
|
||||||
memory_bank,
|
memory_bank, memory_bank, query_norm, mask=src_pad_mask, layer_cache=layer_cache, type="context",
|
||||||
memory_bank,
|
|
||||||
query_norm,
|
|
||||||
mask=src_pad_mask,
|
|
||||||
layer_cache=layer_cache,
|
|
||||||
type="context",
|
|
||||||
)
|
)
|
||||||
output = self.feed_forward(self.drop(mid) + query)
|
output = self.feed_forward(self.drop(mid) + query)
|
||||||
|
|
||||||
@@ -492,14 +424,7 @@ class MultiHeadedAttention(nn.Module):
|
|||||||
self.final_linear = nn.Linear(model_dim, model_dim)
|
self.final_linear = nn.Linear(model_dim, model_dim)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self, key, value, query, mask=None, layer_cache=None, type=None, predefined_graph_1=None,
|
||||||
key,
|
|
||||||
value,
|
|
||||||
query,
|
|
||||||
mask=None,
|
|
||||||
layer_cache=None,
|
|
||||||
type=None,
|
|
||||||
predefined_graph_1=None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compute the context vector and the attention vectors.
|
Compute the context vector and the attention vectors.
|
||||||
@@ -531,11 +456,7 @@ class MultiHeadedAttention(nn.Module):
|
|||||||
|
|
||||||
def unshape(x):
|
def unshape(x):
|
||||||
""" compute context """
|
""" compute context """
|
||||||
return (
|
return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)
|
||||||
x.transpose(1, 2)
|
|
||||||
.contiguous()
|
|
||||||
.view(batch_size, -1, head_count * dim_per_head)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1) Project key, value, and query.
|
# 1) Project key, value, and query.
|
||||||
if layer_cache is not None:
|
if layer_cache is not None:
|
||||||
@@ -554,9 +475,7 @@ class MultiHeadedAttention(nn.Module):
|
|||||||
if layer_cache["self_keys"] is not None:
|
if layer_cache["self_keys"] is not None:
|
||||||
key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2)
|
key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2)
|
||||||
if layer_cache["self_values"] is not None:
|
if layer_cache["self_values"] is not None:
|
||||||
value = torch.cat(
|
value = torch.cat((layer_cache["self_values"].to(device), value), dim=2)
|
||||||
(layer_cache["self_values"].to(device), value), dim=2
|
|
||||||
)
|
|
||||||
layer_cache["self_keys"] = key
|
layer_cache["self_keys"] = key
|
||||||
layer_cache["self_values"] = value
|
layer_cache["self_values"] = value
|
||||||
elif type == "context":
|
elif type == "context":
|
||||||
@@ -637,13 +556,9 @@ class DecoderState(object):
|
|||||||
sizes = e.size()
|
sizes = e.size()
|
||||||
br = sizes[1]
|
br = sizes[1]
|
||||||
if len(sizes) == 3:
|
if len(sizes) == 3:
|
||||||
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[
|
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[:, :, idx]
|
||||||
:, :, idx
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
sent_states = e.view(
|
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2], sizes[3])[:, :, idx]
|
||||||
sizes[0], beam_size, br // beam_size, sizes[2], sizes[3]
|
|
||||||
)[:, :, idx]
|
|
||||||
|
|
||||||
sent_states.data.copy_(sent_states.data.index_select(1, positions))
|
sent_states.data.copy_(sent_states.data.index_select(1, positions))
|
||||||
|
|
||||||
@@ -716,11 +631,7 @@ class TransformerDecoderState(DecoderState):
|
|||||||
|
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
return (
|
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||||
0.5
|
|
||||||
* x
|
|
||||||
* (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PositionwiseFeedForward(nn.Module):
|
class PositionwiseFeedForward(nn.Module):
|
||||||
@@ -758,9 +669,7 @@ class PositionwiseFeedForward(nn.Module):
|
|||||||
def build_predictor(args, tokenizer, symbols, model, logger=None):
|
def build_predictor(args, tokenizer, symbols, model, logger=None):
|
||||||
# we should be able to refactor the global scorer a lot
|
# we should be able to refactor the global scorer a lot
|
||||||
scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu")
|
scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu")
|
||||||
translator = Translator(
|
translator = Translator(args, model, tokenizer, symbols, global_scorer=scorer, logger=logger)
|
||||||
args, model, tokenizer, symbols, global_scorer=scorer, logger=logger
|
|
||||||
)
|
|
||||||
return translator
|
return translator
|
||||||
|
|
||||||
|
|
||||||
@@ -891,9 +800,7 @@ class Translator(object):
|
|||||||
Shouldn't need the original dataset.
|
Shouldn't need the original dataset.
|
||||||
"""
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
return self._fast_translate_batch(
|
return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length)
|
||||||
batch, self.max_length, min_length=self.min_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# Where the beam search lives
|
# Where the beam search lives
|
||||||
# I have no idea why it is being called from the method above
|
# I have no idea why it is being called from the method above
|
||||||
@@ -912,26 +819,18 @@ class Translator(object):
|
|||||||
mask_src = batch.mask_src
|
mask_src = batch.mask_src
|
||||||
|
|
||||||
src_features = self.model.bert(src, segs, mask_src)
|
src_features = self.model.bert(src, segs, mask_src)
|
||||||
dec_states = self.model.decoder.init_decoder_state(
|
dec_states = self.model.decoder.init_decoder_state(src, src_features, with_cache=True)
|
||||||
src, src_features, with_cache=True
|
|
||||||
)
|
|
||||||
device = src_features.device
|
device = src_features.device
|
||||||
|
|
||||||
# Tile states and memory beam_size times.
|
# Tile states and memory beam_size times.
|
||||||
dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim))
|
dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim))
|
||||||
src_features = tile(src_features, beam_size, dim=0)
|
src_features = tile(src_features, beam_size, dim=0)
|
||||||
batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
|
batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
|
||||||
beam_offset = torch.arange(
|
beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device)
|
||||||
0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device
|
alive_seq = torch.full([batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device)
|
||||||
)
|
|
||||||
alive_seq = torch.full(
|
|
||||||
[batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Give full probability to the first beam on the first step.
|
# Give full probability to the first beam on the first step.
|
||||||
topk_log_probs = torch.tensor(
|
topk_log_probs = torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=device).repeat(batch_size)
|
||||||
[0.0] + [float("-inf")] * (beam_size - 1), device=device
|
|
||||||
).repeat(batch_size)
|
|
||||||
|
|
||||||
# Structure that holds finished hypotheses.
|
# Structure that holds finished hypotheses.
|
||||||
hypotheses = [[] for _ in range(batch_size)] # noqa: F812
|
hypotheses = [[] for _ in range(batch_size)] # noqa: F812
|
||||||
@@ -948,9 +847,7 @@ class Translator(object):
|
|||||||
# Decoder forward.
|
# Decoder forward.
|
||||||
decoder_input = decoder_input.transpose(0, 1)
|
decoder_input = decoder_input.transpose(0, 1)
|
||||||
|
|
||||||
dec_out, dec_states = self.model.decoder(
|
dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step)
|
||||||
decoder_input, src_features, dec_states, step=step
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generator forward.
|
# Generator forward.
|
||||||
log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0))
|
log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0))
|
||||||
@@ -978,10 +875,7 @@ class Translator(object):
|
|||||||
words = " ".join(words).replace(" ##", "").split()
|
words = " ".join(words).replace(" ##", "").split()
|
||||||
if len(words) <= 3:
|
if len(words) <= 3:
|
||||||
continue
|
continue
|
||||||
trigrams = [
|
trigrams = [(words[i - 1], words[i], words[i + 1]) for i in range(1, len(words) - 1)]
|
||||||
(words[i - 1], words[i], words[i + 1])
|
|
||||||
for i in range(1, len(words) - 1)
|
|
||||||
]
|
|
||||||
trigram = tuple(trigrams[-1])
|
trigram = tuple(trigrams[-1])
|
||||||
if trigram in trigrams[:-1]:
|
if trigram in trigrams[:-1]:
|
||||||
fail = True
|
fail = True
|
||||||
@@ -999,15 +893,11 @@ class Translator(object):
|
|||||||
topk_ids = topk_ids.fmod(vocab_size)
|
topk_ids = topk_ids.fmod(vocab_size)
|
||||||
|
|
||||||
# Map beam_index to batch_index in the flat representation.
|
# Map beam_index to batch_index in the flat representation.
|
||||||
batch_index = topk_beam_index + beam_offset[
|
batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
|
||||||
: topk_beam_index.size(0)
|
|
||||||
].unsqueeze(1)
|
|
||||||
select_indices = batch_index.view(-1)
|
select_indices = batch_index.view(-1)
|
||||||
|
|
||||||
# Append last prediction.
|
# Append last prediction.
|
||||||
alive_seq = torch.cat(
|
alive_seq = torch.cat([alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1)
|
||||||
[alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1
|
|
||||||
)
|
|
||||||
|
|
||||||
is_finished = topk_ids.eq(self.end_token)
|
is_finished = topk_ids.eq(self.end_token)
|
||||||
if step + 1 == max_length:
|
if step + 1 == max_length:
|
||||||
@@ -1040,15 +930,11 @@ class Translator(object):
|
|||||||
topk_log_probs = topk_log_probs.index_select(0, non_finished)
|
topk_log_probs = topk_log_probs.index_select(0, non_finished)
|
||||||
batch_index = batch_index.index_select(0, non_finished)
|
batch_index = batch_index.index_select(0, non_finished)
|
||||||
batch_offset = batch_offset.index_select(0, non_finished)
|
batch_offset = batch_offset.index_select(0, non_finished)
|
||||||
alive_seq = predictions.index_select(0, non_finished).view(
|
alive_seq = predictions.index_select(0, non_finished).view(-1, alive_seq.size(-1))
|
||||||
-1, alive_seq.size(-1)
|
|
||||||
)
|
|
||||||
# Reorder states.
|
# Reorder states.
|
||||||
select_indices = batch_index.view(-1)
|
select_indices = batch_index.view(-1)
|
||||||
src_features = src_features.index_select(0, select_indices)
|
src_features = src_features.index_select(0, select_indices)
|
||||||
dec_states.map_batch_fn(
|
dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))
|
||||||
lambda state, dim: state.index_select(dim, select_indices)
|
|
||||||
)
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -1089,14 +975,7 @@ def tile(x, count, dim=0):
|
|||||||
out_size = list(x.size())
|
out_size = list(x.size())
|
||||||
out_size[0] *= count
|
out_size[0] *= count
|
||||||
batch = x.size(0)
|
batch = x.size(0)
|
||||||
x = (
|
x = x.view(batch, -1).transpose(0, 1).repeat(count, 1).transpose(0, 1).contiguous().view(*out_size)
|
||||||
x.view(batch, -1)
|
|
||||||
.transpose(0, 1)
|
|
||||||
.repeat(count, 1)
|
|
||||||
.transpose(0, 1)
|
|
||||||
.contiguous()
|
|
||||||
.view(*out_size)
|
|
||||||
)
|
|
||||||
if dim != 0:
|
if dim != 0:
|
||||||
x = x.permute(perm).contiguous()
|
x = x.permute(perm).contiguous()
|
||||||
return x
|
return x
|
||||||
@@ -1107,6 +986,7 @@ def tile(x, count, dim=0):
|
|||||||
# a finetuning script.
|
# a finetuning script.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
class BertSumOptimizer(object):
|
class BertSumOptimizer(object):
|
||||||
""" Specific optimizer for BertSum.
|
""" Specific optimizer for BertSum.
|
||||||
|
|
||||||
@@ -1126,16 +1006,10 @@ class BertSumOptimizer(object):
|
|||||||
|
|
||||||
self.optimizers = {
|
self.optimizers = {
|
||||||
"encoder": torch.optim.Adam(
|
"encoder": torch.optim.Adam(
|
||||||
model.encoder.parameters(),
|
model.encoder.parameters(), lr=lr["encoder"], betas=(beta_1, beta_2), eps=eps,
|
||||||
lr=lr["encoder"],
|
|
||||||
betas=(beta_1, beta_2),
|
|
||||||
eps=eps,
|
|
||||||
),
|
),
|
||||||
"decoder": torch.optim.Adam(
|
"decoder": torch.optim.Adam(
|
||||||
model.decoder.parameters(),
|
model.decoder.parameters(), lr=lr["decoder"], betas=(beta_1, beta_2), eps=eps,
|
||||||
lr=lr["decoder"],
|
|
||||||
betas=(beta_1, beta_2),
|
|
||||||
eps=eps,
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1143,9 +1017,7 @@ class BertSumOptimizer(object):
|
|||||||
self.current_learning_rates = {}
|
self.current_learning_rates = {}
|
||||||
|
|
||||||
def _update_rate(self, stack):
|
def _update_rate(self, stack):
|
||||||
return self.lr[stack] * min(
|
return self.lr[stack] * min(self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5))
|
||||||
self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5)
|
|
||||||
)
|
|
||||||
|
|
||||||
def zero_grad(self):
|
def zero_grad(self):
|
||||||
self.optimizer_decoder.zero_grad()
|
self.optimizer_decoder.zero_grad()
|
||||||
|
|||||||
@@ -25,9 +25,7 @@ logger = logging.getLogger(__name__)
|
|||||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
Batch = namedtuple(
|
Batch = namedtuple("Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"])
|
||||||
"Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(args):
|
def evaluate(args):
|
||||||
@@ -48,13 +46,14 @@ def evaluate(args):
|
|||||||
|
|
||||||
import rouge
|
import rouge
|
||||||
import nltk
|
import nltk
|
||||||
nltk.download('punkt')
|
|
||||||
|
nltk.download("punkt")
|
||||||
rouge_evaluator = rouge.Rouge(
|
rouge_evaluator = rouge.Rouge(
|
||||||
metrics=['rouge-n', 'rouge-l'],
|
metrics=["rouge-n", "rouge-l"],
|
||||||
max_n=2,
|
max_n=2,
|
||||||
limit_length=True,
|
limit_length=True,
|
||||||
length_limit=args.beam_size,
|
length_limit=args.beam_size,
|
||||||
length_limit_type='words',
|
length_limit_type="words",
|
||||||
apply_avg=True,
|
apply_avg=True,
|
||||||
apply_best=False,
|
apply_best=False,
|
||||||
alpha=0.5, # Default F1_score
|
alpha=0.5, # Default F1_score
|
||||||
@@ -161,15 +160,15 @@ Recall >> {:.3f}
|
|||||||
F1 >> {:.3f}
|
F1 >> {:.3f}
|
||||||
Precision >> {:.3f}
|
Precision >> {:.3f}
|
||||||
Recall >> {:.3f}""".format(
|
Recall >> {:.3f}""".format(
|
||||||
scores['rouge-1']['f'],
|
scores["rouge-1"]["f"],
|
||||||
scores['rouge-1']['p'],
|
scores["rouge-1"]["p"],
|
||||||
scores['rouge-1']['r'],
|
scores["rouge-1"]["r"],
|
||||||
scores['rouge-2']['f'],
|
scores["rouge-2"]["f"],
|
||||||
scores['rouge-2']['p'],
|
scores["rouge-2"]["p"],
|
||||||
scores['rouge-2']['r'],
|
scores["rouge-2"]["r"],
|
||||||
scores['rouge-l']['f'],
|
scores["rouge-l"]["f"],
|
||||||
scores['rouge-l']['p'],
|
scores["rouge-l"]["p"],
|
||||||
scores['rouge-l']['r'],
|
scores["rouge-l"]["r"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -187,9 +186,7 @@ def build_data_iterator(args, tokenizer):
|
|||||||
dataset = load_and_cache_examples(args, tokenizer)
|
dataset = load_and_cache_examples(args, tokenizer)
|
||||||
sampler = SequentialSampler(dataset)
|
sampler = SequentialSampler(dataset)
|
||||||
collate_fn = lambda data: collate(data, tokenizer, block_size=512, device=args.device)
|
collate_fn = lambda data: collate(data, tokenizer, block_size=512, device=args.device)
|
||||||
iterator = DataLoader(
|
iterator = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,)
|
||||||
dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
return iterator
|
return iterator
|
||||||
|
|
||||||
@@ -210,14 +207,9 @@ def collate(data, tokenizer, block_size, device):
|
|||||||
names = [name for name, _, _ in data]
|
names = [name for name, _, _ in data]
|
||||||
summaries = [" ".join(summary_list) for _, _, summary_list in data]
|
summaries = [" ".join(summary_list) for _, _, summary_list in data]
|
||||||
|
|
||||||
encoded_text = [
|
encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
|
||||||
encode_for_summarization(story, summary, tokenizer) for _, story, summary in data
|
|
||||||
]
|
|
||||||
encoded_stories = torch.tensor(
|
encoded_stories = torch.tensor(
|
||||||
[
|
[fit_to_block_size(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
|
||||||
fit_to_block_size(story, block_size, tokenizer.pad_token_id)
|
|
||||||
for story, _ in encoded_text
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
|
encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
|
||||||
encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
|
encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
|
||||||
@@ -272,38 +264,23 @@ def main():
|
|||||||
)
|
)
|
||||||
# EVALUATION options
|
# EVALUATION options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no_cuda",
|
"--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
|
||||||
default=False,
|
|
||||||
type=bool,
|
|
||||||
help="Whether to force the execution on CPU.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
|
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
|
||||||
)
|
)
|
||||||
# BEAM SEARCH arguments
|
# BEAM SEARCH arguments
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min_length",
|
"--min_length", default=50, type=int, help="Minimum number of tokens for the summaries.",
|
||||||
default=50,
|
|
||||||
type=int,
|
|
||||||
help="Minimum number of tokens for the summaries.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_length",
|
"--max_length", default=200, type=int, help="Maixmum number of tokens for the summaries.",
|
||||||
default=200,
|
|
||||||
type=int,
|
|
||||||
help="Maixmum number of tokens for the summaries.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--beam_size",
|
"--beam_size", default=5, type=int, help="The number of beams to start with for each example.",
|
||||||
default=5,
|
|
||||||
type=int,
|
|
||||||
help="The number of beams to start with for each example.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--alpha",
|
"--alpha", default=0.95, type=float, help="The value of alpha for the length penalty in the beam search.",
|
||||||
default=0.95,
|
|
||||||
type=float,
|
|
||||||
help="The value of alpha for the length penalty in the beam search.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--block_trigram",
|
"--block_trigram",
|
||||||
|
|||||||
@@ -68,9 +68,7 @@ def process_story(raw_story):
|
|||||||
Raises:
|
Raises:
|
||||||
IndexError: If the stoy is empty or contains no highlights.
|
IndexError: If the stoy is empty or contains no highlights.
|
||||||
"""
|
"""
|
||||||
nonempty_lines = list(
|
nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
|
||||||
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
|
|
||||||
)
|
|
||||||
|
|
||||||
# for some unknown reason some lines miss a period, add it
|
# for some unknown reason some lines miss a period, add it
|
||||||
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
|
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
|
||||||
@@ -135,13 +133,9 @@ def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
|||||||
sentences.
|
sentences.
|
||||||
"""
|
"""
|
||||||
story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
|
story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
|
||||||
story_token_ids = [
|
story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
|
||||||
token for sentence in story_lines_token_ids for token in sentence
|
|
||||||
]
|
|
||||||
summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
|
summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
|
||||||
summary_token_ids = [
|
summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]
|
||||||
token for sentence in summary_lines_token_ids for token in sentence
|
|
||||||
]
|
|
||||||
|
|
||||||
return story_token_ids, summary_token_ids
|
return story_token_ids, summary_token_ids
|
||||||
|
|
||||||
|
|||||||
@@ -33,25 +33,19 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
|||||||
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
|
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
|
||||||
sequence = [1, 2, 3, 4]
|
sequence = [1, 2, 3, 4]
|
||||||
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
|
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
|
||||||
self.assertEqual(
|
self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
|
||||||
fit_to_block_size(sequence, self.block_size, 0), expected_output
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_fit_to_block_sequence_fit_exactly(self):
|
def test_fit_to_block_sequence_fit_exactly(self):
|
||||||
""" Do nothing if the sequence is the right size. """
|
""" Do nothing if the sequence is the right size. """
|
||||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||||
self.assertEqual(
|
self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
|
||||||
fit_to_block_size(sequence, self.block_size, 0), expected_output
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_fit_to_block_sequence_too_big(self):
|
def test_fit_to_block_sequence_too_big(self):
|
||||||
""" Truncate the sequence if it is too long. """
|
""" Truncate the sequence if it is too long. """
|
||||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
|
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
|
||||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||||
self.assertEqual(
|
self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
|
||||||
fit_to_block_size(sequence, self.block_size, 0), expected_output
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_process_story_no_highlights(self):
|
def test_process_story_no_highlights(self):
|
||||||
""" Processing a story with no highlights returns an empty list for the summary.
|
""" Processing a story with no highlights returns an empty list for the summary.
|
||||||
@@ -95,9 +89,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
|||||||
def test_build_mask(self):
|
def test_build_mask(self):
|
||||||
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
|
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
|
||||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(build_mask(sequence, 23).numpy(), expected.numpy())
|
||||||
build_mask(sequence, 23).numpy(), expected.numpy()
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_build_mask_with_padding_equal_to_one(self):
|
def test_build_mask_with_padding_equal_to_one(self):
|
||||||
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
|
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
|
||||||
@@ -106,12 +98,8 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_compute_token_type_ids(self):
|
def test_compute_token_type_ids(self):
|
||||||
separator = 101
|
separator = 101
|
||||||
batch = torch.tensor(
|
batch = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]])
|
||||||
[[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
|
expected = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]])
|
||||||
)
|
|
||||||
expected = torch.tensor(
|
|
||||||
[[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = compute_token_type_ids(batch, separator)
|
result = compute_token_type_ids(batch, separator)
|
||||||
np.testing.assert_array_equal(result, expected)
|
np.testing.assert_array_equal(result, expected)
|
||||||
|
|||||||
@@ -35,19 +35,21 @@ logging.basicConfig(level=logging.DEBUG)
|
|||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
def get_setup_file():
|
def get_setup_file():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-f')
|
parser.add_argument("-f")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args.f
|
return args.f
|
||||||
|
|
||||||
class ExamplesTests(unittest.TestCase):
|
|
||||||
|
|
||||||
|
class ExamplesTests(unittest.TestCase):
|
||||||
def test_run_glue(self):
|
def test_run_glue(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
testargs = ["run_glue.py",
|
testargs = [
|
||||||
|
"run_glue.py",
|
||||||
"--data_dir=./examples/tests_samples/MRPC/",
|
"--data_dir=./examples/tests_samples/MRPC/",
|
||||||
"--task_name=mrpc",
|
"--task_name=mrpc",
|
||||||
"--do_train",
|
"--do_train",
|
||||||
@@ -59,10 +61,10 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
"--max_steps=10",
|
"--max_steps=10",
|
||||||
"--warmup_steps=2",
|
"--warmup_steps=2",
|
||||||
"--overwrite_output_dir",
|
"--overwrite_output_dir",
|
||||||
"--seed=42"]
|
"--seed=42",
|
||||||
model_type, model_name = ("--model_type=bert",
|
]
|
||||||
"--model_name_or_path=bert-base-uncased")
|
model_type, model_name = ("--model_type=bert", "--model_name_or_path=bert-base-uncased")
|
||||||
with patch.object(sys, 'argv', testargs + [model_type, model_name]):
|
with patch.object(sys, "argv", testargs + [model_type, model_name]):
|
||||||
result = run_glue.main()
|
result = run_glue.main()
|
||||||
for value in result.values():
|
for value in result.values():
|
||||||
self.assertGreaterEqual(value, 0.75)
|
self.assertGreaterEqual(value, 0.75)
|
||||||
@@ -71,7 +73,8 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
testargs = ["run_squad.py",
|
testargs = [
|
||||||
|
"run_squad.py",
|
||||||
"--data_dir=./examples/tests_samples/SQUAD",
|
"--data_dir=./examples/tests_samples/SQUAD",
|
||||||
"--model_name=bert-base-uncased",
|
"--model_name=bert-base-uncased",
|
||||||
"--output_dir=./examples/tests_samples/temp_dir",
|
"--output_dir=./examples/tests_samples/temp_dir",
|
||||||
@@ -84,27 +87,24 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
"--per_gpu_train_batch_size=2",
|
"--per_gpu_train_batch_size=2",
|
||||||
"--per_gpu_eval_batch_size=1",
|
"--per_gpu_eval_batch_size=1",
|
||||||
"--overwrite_output_dir",
|
"--overwrite_output_dir",
|
||||||
"--seed=42"]
|
"--seed=42",
|
||||||
model_type, model_name = ("--model_type=bert",
|
]
|
||||||
"--model_name_or_path=bert-base-uncased")
|
model_type, model_name = ("--model_type=bert", "--model_name_or_path=bert-base-uncased")
|
||||||
with patch.object(sys, 'argv', testargs + [model_type, model_name]):
|
with patch.object(sys, "argv", testargs + [model_type, model_name]):
|
||||||
result = run_squad.main()
|
result = run_squad.main()
|
||||||
self.assertGreaterEqual(result['f1'], 30)
|
self.assertGreaterEqual(result["f1"], 30)
|
||||||
self.assertGreaterEqual(result['exact'], 30)
|
self.assertGreaterEqual(result["exact"], 30)
|
||||||
|
|
||||||
def test_generation(self):
|
def test_generation(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
testargs = ["run_generation.py",
|
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
|
||||||
"--prompt=Hello",
|
model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt")
|
||||||
"--length=10",
|
with patch.object(sys, "argv", testargs + [model_type, model_name]):
|
||||||
"--seed=42"]
|
|
||||||
model_type, model_name = ("--model_type=openai-gpt",
|
|
||||||
"--model_name_or_path=openai-gpt")
|
|
||||||
with patch.object(sys, 'argv', testargs + [model_type, model_name]):
|
|
||||||
result = run_generation.main()
|
result = run_generation.main()
|
||||||
self.assertGreaterEqual(len(result), 10)
|
self.assertGreaterEqual(len(result), 10)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -55,19 +55,10 @@ class InputExample(object):
|
|||||||
|
|
||||||
|
|
||||||
class InputFeatures(object):
|
class InputFeatures(object):
|
||||||
def __init__(self,
|
def __init__(self, example_id, choices_features, label):
|
||||||
example_id,
|
|
||||||
choices_features,
|
|
||||||
label
|
|
||||||
|
|
||||||
):
|
|
||||||
self.example_id = example_id
|
self.example_id = example_id
|
||||||
self.choices_features = [
|
self.choices_features = [
|
||||||
{
|
{"input_ids": input_ids, "input_mask": input_mask, "segment_ids": segment_ids}
|
||||||
'input_ids': input_ids,
|
|
||||||
'input_mask': input_mask,
|
|
||||||
'segment_ids': segment_ids
|
|
||||||
}
|
|
||||||
for input_ids, input_mask, segment_ids in choices_features
|
for input_ids, input_mask, segment_ids in choices_features
|
||||||
]
|
]
|
||||||
self.label = label
|
self.label = label
|
||||||
@@ -99,29 +90,29 @@ class RaceProcessor(DataProcessor):
|
|||||||
def get_train_examples(self, data_dir):
|
def get_train_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
logger.info("LOOKING AT {} train".format(data_dir))
|
logger.info("LOOKING AT {} train".format(data_dir))
|
||||||
high = os.path.join(data_dir, 'train/high')
|
high = os.path.join(data_dir, "train/high")
|
||||||
middle = os.path.join(data_dir, 'train/middle')
|
middle = os.path.join(data_dir, "train/middle")
|
||||||
high = self._read_txt(high)
|
high = self._read_txt(high)
|
||||||
middle = self._read_txt(middle)
|
middle = self._read_txt(middle)
|
||||||
return self._create_examples(high + middle, 'train')
|
return self._create_examples(high + middle, "train")
|
||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||||
high = os.path.join(data_dir, 'dev/high')
|
high = os.path.join(data_dir, "dev/high")
|
||||||
middle = os.path.join(data_dir, 'dev/middle')
|
middle = os.path.join(data_dir, "dev/middle")
|
||||||
high = self._read_txt(high)
|
high = self._read_txt(high)
|
||||||
middle = self._read_txt(middle)
|
middle = self._read_txt(middle)
|
||||||
return self._create_examples(high + middle, 'dev')
|
return self._create_examples(high + middle, "dev")
|
||||||
|
|
||||||
def get_test_examples(self, data_dir):
|
def get_test_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
logger.info("LOOKING AT {} test".format(data_dir))
|
logger.info("LOOKING AT {} test".format(data_dir))
|
||||||
high = os.path.join(data_dir, 'test/high')
|
high = os.path.join(data_dir, "test/high")
|
||||||
middle = os.path.join(data_dir, 'test/middle')
|
middle = os.path.join(data_dir, "test/middle")
|
||||||
high = self._read_txt(high)
|
high = self._read_txt(high)
|
||||||
middle = self._read_txt(middle)
|
middle = self._read_txt(middle)
|
||||||
return self._create_examples(high + middle, 'test')
|
return self._create_examples(high + middle, "test")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
@@ -131,13 +122,12 @@ class RaceProcessor(DataProcessor):
|
|||||||
lines = []
|
lines = []
|
||||||
files = glob.glob(input_dir + "/*txt")
|
files = glob.glob(input_dir + "/*txt")
|
||||||
for file in tqdm.tqdm(files, desc="read files"):
|
for file in tqdm.tqdm(files, desc="read files"):
|
||||||
with open(file, 'r', encoding='utf-8') as fin:
|
with open(file, "r", encoding="utf-8") as fin:
|
||||||
data_raw = json.load(fin)
|
data_raw = json.load(fin)
|
||||||
data_raw["race_id"] = file
|
data_raw["race_id"] = file
|
||||||
lines.append(data_raw)
|
lines.append(data_raw)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training and dev sets."""
|
||||||
examples = []
|
examples = []
|
||||||
@@ -145,9 +135,9 @@ class RaceProcessor(DataProcessor):
|
|||||||
race_id = "%s-%s" % (set_type, data_raw["race_id"])
|
race_id = "%s-%s" % (set_type, data_raw["race_id"])
|
||||||
article = data_raw["article"]
|
article = data_raw["article"]
|
||||||
for i in range(len(data_raw["answers"])):
|
for i in range(len(data_raw["answers"])):
|
||||||
truth = str(ord(data_raw['answers'][i]) - ord('A'))
|
truth = str(ord(data_raw["answers"][i]) - ord("A"))
|
||||||
question = data_raw['questions'][i]
|
question = data_raw["questions"][i]
|
||||||
options = data_raw['options'][i]
|
options = data_raw["options"][i]
|
||||||
|
|
||||||
examples.append(
|
examples.append(
|
||||||
InputExample(
|
InputExample(
|
||||||
@@ -155,9 +145,12 @@ class RaceProcessor(DataProcessor):
|
|||||||
question=question,
|
question=question,
|
||||||
contexts=[article, article, article, article], # this is not efficient but convenient
|
contexts=[article, article, article, article], # this is not efficient but convenient
|
||||||
endings=[options[0], options[1], options[2], options[3]],
|
endings=[options[0], options[1], options[2], options[3]],
|
||||||
label=truth))
|
label=truth,
|
||||||
|
)
|
||||||
|
)
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
class SwagProcessor(DataProcessor):
|
class SwagProcessor(DataProcessor):
|
||||||
"""Processor for the SWAG data set."""
|
"""Processor for the SWAG data set."""
|
||||||
|
|
||||||
@@ -179,27 +172,25 @@ class SwagProcessor(DataProcessor):
|
|||||||
"setting!"
|
"setting!"
|
||||||
)
|
)
|
||||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")
|
return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["0", "1", "2", "3"]
|
return ["0", "1", "2", "3"]
|
||||||
|
|
||||||
def _read_csv(self, input_file):
|
def _read_csv(self, input_file):
|
||||||
with open(input_file, 'r', encoding='utf-8') as f:
|
with open(input_file, "r", encoding="utf-8") as f:
|
||||||
reader = csv.reader(f)
|
reader = csv.reader(f)
|
||||||
lines = []
|
lines = []
|
||||||
for line in reader:
|
for line in reader:
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
line = list(unicode(cell, 'utf-8') for cell in line)
|
line = list(unicode(cell, "utf-8") for cell in line)
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
def _create_examples(self, lines: List[List[str]], type: str):
|
def _create_examples(self, lines: List[List[str]], type: str):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training and dev sets."""
|
||||||
if type == "train" and lines[0][-1] != 'label':
|
if type == "train" and lines[0][-1] != "label":
|
||||||
raise ValueError(
|
raise ValueError("For training, the input file must contain a label column.")
|
||||||
"For training, the input file must contain a label column."
|
|
||||||
)
|
|
||||||
|
|
||||||
examples = [
|
examples = [
|
||||||
InputExample(
|
InputExample(
|
||||||
@@ -209,8 +200,9 @@ class SwagProcessor(DataProcessor):
|
|||||||
# choice is stored in "sent2".
|
# choice is stored in "sent2".
|
||||||
contexts=[line[4], line[4], line[4], line[4]],
|
contexts=[line[4], line[4], line[4], line[4]],
|
||||||
endings=[line[7], line[8], line[9], line[10]],
|
endings=[line[7], line[8], line[9], line[10]],
|
||||||
label=line[11]
|
label=line[11],
|
||||||
) for line in lines[1:] # we skip the line with the column names
|
)
|
||||||
|
for line in lines[1:] # we skip the line with the column names
|
||||||
]
|
]
|
||||||
|
|
||||||
return examples
|
return examples
|
||||||
@@ -238,11 +230,10 @@ class ArcProcessor(DataProcessor):
|
|||||||
return ["0", "1", "2", "3"]
|
return ["0", "1", "2", "3"]
|
||||||
|
|
||||||
def _read_json(self, input_file):
|
def _read_json(self, input_file):
|
||||||
with open(input_file, 'r', encoding='utf-8') as fin:
|
with open(input_file, "r", encoding="utf-8") as fin:
|
||||||
lines = fin.readlines()
|
lines = fin.readlines()
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
def _create_examples(self, lines, type):
|
def _create_examples(self, lines, type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training and dev sets."""
|
||||||
|
|
||||||
@@ -285,10 +276,16 @@ class ArcProcessor(DataProcessor):
|
|||||||
InputExample(
|
InputExample(
|
||||||
example_id=id,
|
example_id=id,
|
||||||
question=question,
|
question=question,
|
||||||
contexts=[options[0]["para"].replace("_", ""), options[1]["para"].replace("_", ""),
|
contexts=[
|
||||||
options[2]["para"].replace("_", ""), options[3]["para"].replace("_", "")],
|
options[0]["para"].replace("_", ""),
|
||||||
|
options[1]["para"].replace("_", ""),
|
||||||
|
options[2]["para"].replace("_", ""),
|
||||||
|
options[3]["para"].replace("_", ""),
|
||||||
|
],
|
||||||
endings=[options[0]["text"], options[1]["text"], options[2]["text"], options[3]["text"]],
|
endings=[options[0]["text"], options[1]["text"], options[2]["text"], options[3]["text"]],
|
||||||
label=truth))
|
label=truth,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if type == "train":
|
if type == "train":
|
||||||
assert len(examples) > 1
|
assert len(examples) > 1
|
||||||
@@ -331,16 +328,13 @@ def convert_examples_to_features(
|
|||||||
else:
|
else:
|
||||||
text_b = example.question + " " + ending
|
text_b = example.question + " " + ending
|
||||||
|
|
||||||
inputs = tokenizer.encode_plus(
|
inputs = tokenizer.encode_plus(text_a, text_b, add_special_tokens=True, max_length=max_length,)
|
||||||
text_a,
|
if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
|
||||||
text_b,
|
logger.info(
|
||||||
add_special_tokens=True,
|
"Attention! you are cropping tokens (swag task is ok). "
|
||||||
max_length=max_length,
|
"If you are training ARC and RACE and you are poping question + options,"
|
||||||
|
"you need to try to use a bigger max seq length!"
|
||||||
)
|
)
|
||||||
if 'num_truncated_tokens' in inputs and inputs['num_truncated_tokens'] > 0:
|
|
||||||
logger.info('Attention! you are cropping tokens (swag task is ok). '
|
|
||||||
'If you are training ARC and RACE and you are poping question + options,'
|
|
||||||
'you need to try to use a bigger max seq length!')
|
|
||||||
|
|
||||||
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
|
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
|
||||||
|
|
||||||
@@ -364,7 +358,6 @@ def convert_examples_to_features(
|
|||||||
assert len(token_type_ids) == max_length
|
assert len(token_type_ids) == max_length
|
||||||
choices_features.append((input_ids, attention_mask, token_type_ids))
|
choices_features.append((input_ids, attention_mask, token_type_ids))
|
||||||
|
|
||||||
|
|
||||||
label = label_map[example.label]
|
label = label_map[example.label]
|
||||||
|
|
||||||
if ex_index < 2:
|
if ex_index < 2:
|
||||||
@@ -372,33 +365,17 @@ def convert_examples_to_features(
|
|||||||
logger.info("race_id: {}".format(example.example_id))
|
logger.info("race_id: {}".format(example.example_id))
|
||||||
for choice_idx, (input_ids, attention_mask, token_type_ids) in enumerate(choices_features):
|
for choice_idx, (input_ids, attention_mask, token_type_ids) in enumerate(choices_features):
|
||||||
logger.info("choice: {}".format(choice_idx))
|
logger.info("choice: {}".format(choice_idx))
|
||||||
logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
|
logger.info("input_ids: {}".format(" ".join(map(str, input_ids))))
|
||||||
logger.info("attention_mask: {}".format(' '.join(map(str, attention_mask))))
|
logger.info("attention_mask: {}".format(" ".join(map(str, attention_mask))))
|
||||||
logger.info("token_type_ids: {}".format(' '.join(map(str, token_type_ids))))
|
logger.info("token_type_ids: {}".format(" ".join(map(str, token_type_ids))))
|
||||||
logger.info("label: {}".format(label))
|
logger.info("label: {}".format(label))
|
||||||
|
|
||||||
features.append(
|
features.append(InputFeatures(example_id=example.example_id, choices_features=choices_features, label=label,))
|
||||||
InputFeatures(
|
|
||||||
example_id=example.example_id,
|
|
||||||
choices_features=choices_features,
|
|
||||||
label=label,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
processors = {"race": RaceProcessor, "swag": SwagProcessor, "arc": ArcProcessor}
|
||||||
|
|
||||||
|
|
||||||
processors = {
|
MULTIPLE_CHOICE_TASKS_NUM_LABELS = {"race", 4, "swag", 4, "arc", 4}
|
||||||
"race": RaceProcessor,
|
|
||||||
"swag": SwagProcessor,
|
|
||||||
"arc": ArcProcessor
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
MULTIPLE_CHOICE_TASKS_NUM_LABELS = {
|
|
||||||
"race", 4,
|
|
||||||
"swag", 4,
|
|
||||||
"arc", 4
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -61,9 +61,7 @@ def read_examples_from_file(data_dir, mode):
|
|||||||
for line in f:
|
for line in f:
|
||||||
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
||||||
if words:
|
if words:
|
||||||
examples.append(InputExample(guid="{}-{}".format(mode, guid_index),
|
examples.append(InputExample(guid="{}-{}".format(mode, guid_index), words=words, labels=labels))
|
||||||
words=words,
|
|
||||||
labels=labels))
|
|
||||||
guid_index += 1
|
guid_index += 1
|
||||||
words = []
|
words = []
|
||||||
labels = []
|
labels = []
|
||||||
@@ -76,13 +74,12 @@ def read_examples_from_file(data_dir, mode):
|
|||||||
# Examples could have no label for mode = "test"
|
# Examples could have no label for mode = "test"
|
||||||
labels.append("O")
|
labels.append("O")
|
||||||
if words:
|
if words:
|
||||||
examples.append(InputExample(guid="%s-%d".format(mode, guid_index),
|
examples.append(InputExample(guid="%s-%d".format(mode, guid_index), words=words, labels=labels))
|
||||||
words=words,
|
|
||||||
labels=labels))
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
def convert_examples_to_features(examples,
|
def convert_examples_to_features(
|
||||||
|
examples,
|
||||||
label_list,
|
label_list,
|
||||||
max_seq_length,
|
max_seq_length,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -96,7 +93,8 @@ def convert_examples_to_features(examples,
|
|||||||
pad_token_segment_id=0,
|
pad_token_segment_id=0,
|
||||||
pad_token_label_id=-100,
|
pad_token_label_id=-100,
|
||||||
sequence_a_segment_id=0,
|
sequence_a_segment_id=0,
|
||||||
mask_padding_with_zero=True):
|
mask_padding_with_zero=True,
|
||||||
|
):
|
||||||
""" Loads a data file into a list of `InputBatch`s
|
""" Loads a data file into a list of `InputBatch`s
|
||||||
`cls_token_at_end` define the location of the CLS token:
|
`cls_token_at_end` define the location of the CLS token:
|
||||||
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
|
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
|
||||||
@@ -174,10 +172,10 @@ def convert_examples_to_features(examples,
|
|||||||
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
|
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
|
||||||
label_ids = ([pad_token_label_id] * padding_length) + label_ids
|
label_ids = ([pad_token_label_id] * padding_length) + label_ids
|
||||||
else:
|
else:
|
||||||
input_ids += ([pad_token] * padding_length)
|
input_ids += [pad_token] * padding_length
|
||||||
input_mask += ([0 if mask_padding_with_zero else 1] * padding_length)
|
input_mask += [0 if mask_padding_with_zero else 1] * padding_length
|
||||||
segment_ids += ([pad_token_segment_id] * padding_length)
|
segment_ids += [pad_token_segment_id] * padding_length
|
||||||
label_ids += ([pad_token_label_id] * padding_length)
|
label_ids += [pad_token_label_id] * padding_length
|
||||||
|
|
||||||
assert len(input_ids) == max_seq_length
|
assert len(input_ids) == max_seq_length
|
||||||
assert len(input_mask) == max_seq_length
|
assert len(input_mask) == max_seq_length
|
||||||
@@ -194,10 +192,8 @@ def convert_examples_to_features(examples,
|
|||||||
logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))
|
logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))
|
||||||
|
|
||||||
features.append(
|
features.append(
|
||||||
InputFeatures(input_ids=input_ids,
|
InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids)
|
||||||
input_mask=input_mask,
|
)
|
||||||
segment_ids=segment_ids,
|
|
||||||
label_ids=label_ids))
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
11
hubconf.py
11
hubconf.py
@@ -1,9 +1,15 @@
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoTokenizer, AutoConfig, AutoModel, AutoModelWithLMHead, AutoModelForSequenceClassification, AutoModelForQuestionAnswering
|
AutoTokenizer,
|
||||||
|
AutoConfig,
|
||||||
|
AutoModel,
|
||||||
|
AutoModelWithLMHead,
|
||||||
|
AutoModelForSequenceClassification,
|
||||||
|
AutoModelForQuestionAnswering,
|
||||||
)
|
)
|
||||||
from transformers.file_utils import add_start_docstrings
|
from transformers.file_utils import add_start_docstrings
|
||||||
|
|
||||||
dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex', 'sentencepiece', 'sacremoses']
|
dependencies = ["torch", "tqdm", "boto3", "requests", "regex", "sentencepiece", "sacremoses"]
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(AutoConfig.__doc__)
|
@add_start_docstrings(AutoConfig.__doc__)
|
||||||
def config(*args, **kwargs):
|
def config(*args, **kwargs):
|
||||||
@@ -57,6 +63,7 @@ def model(*args, **kwargs):
|
|||||||
|
|
||||||
return AutoModel.from_pretrained(*args, **kwargs)
|
return AutoModel.from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(AutoModelWithLMHead.__doc__)
|
@add_start_docstrings(AutoModelWithLMHead.__doc__)
|
||||||
def modelWithLMHead(*args, **kwargs):
|
def modelWithLMHead(*args, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
47
setup.py
47
setup.py
@@ -38,11 +38,11 @@ from setuptools import find_packages, setup
|
|||||||
|
|
||||||
|
|
||||||
extras = {
|
extras = {
|
||||||
'serving': ['pydantic', 'uvicorn', 'fastapi'],
|
"serving": ["pydantic", "uvicorn", "fastapi"],
|
||||||
'serving-tf': ['pydantic', 'uvicorn', 'fastapi', 'tensorflow'],
|
"serving-tf": ["pydantic", "uvicorn", "fastapi", "tensorflow"],
|
||||||
'serving-torch': ['pydantic', 'uvicorn', 'fastapi', 'torch']
|
"serving-torch": ["pydantic", "uvicorn", "fastapi", "torch"],
|
||||||
}
|
}
|
||||||
extras['all'] = [package for package in extras.values()]
|
extras["all"] = [package for package in extras.values()]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="transformers",
|
name="transformers",
|
||||||
@@ -50,30 +50,29 @@ setup(
|
|||||||
author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
|
author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
|
||||||
author_email="thomas@huggingface.co",
|
author_email="thomas@huggingface.co",
|
||||||
description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",
|
description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",
|
||||||
long_description=open("README.md", "r", encoding='utf-8').read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
keywords='NLP deep learning transformer pytorch tensorflow BERT GPT GPT-2 google openai CMU',
|
keywords="NLP deep learning transformer pytorch tensorflow BERT GPT GPT-2 google openai CMU",
|
||||||
license='Apache',
|
license="Apache",
|
||||||
url="https://github.com/huggingface/transformers",
|
url="https://github.com/huggingface/transformers",
|
||||||
packages=find_packages(exclude=["*.tests", "*.tests.*",
|
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
|
||||||
"tests.*", "tests"]),
|
install_requires=[
|
||||||
install_requires=['numpy',
|
"numpy",
|
||||||
'boto3',
|
"boto3",
|
||||||
'filelock',
|
"filelock",
|
||||||
'requests',
|
"requests",
|
||||||
'tqdm',
|
"tqdm",
|
||||||
'regex != 2019.12.17',
|
"regex != 2019.12.17",
|
||||||
'sentencepiece',
|
"sentencepiece",
|
||||||
'sacremoses'],
|
"sacremoses",
|
||||||
extras_require=extras,
|
|
||||||
scripts=[
|
|
||||||
'transformers-cli'
|
|
||||||
],
|
],
|
||||||
|
extras_require=extras,
|
||||||
|
scripts=["transformers-cli"],
|
||||||
# python_requires='>=3.5.0',
|
# python_requires='>=3.5.0',
|
||||||
classifiers=[
|
classifiers=[
|
||||||
'Intended Audience :: Science/Research',
|
"Intended Audience :: Science/Research",
|
||||||
'License :: OSI Approved :: Apache Software License',
|
"License :: OSI Approved :: Apache Software License",
|
||||||
'Programming Language :: Python :: 3',
|
"Programming Language :: Python :: 3",
|
||||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,8 +24,7 @@ import glob
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
TensorDataset)
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -35,19 +34,32 @@ except:
|
|||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
from transformers import (
|
||||||
BertForQuestionAnswering, BertTokenizer,
|
WEIGHTS_NAME,
|
||||||
XLMConfig, XLMForQuestionAnswering,
|
BertConfig,
|
||||||
XLMTokenizer, XLNetConfig,
|
BertForQuestionAnswering,
|
||||||
|
BertTokenizer,
|
||||||
|
XLMConfig,
|
||||||
|
XLMForQuestionAnswering,
|
||||||
|
XLMTokenizer,
|
||||||
|
XLNetConfig,
|
||||||
XLNetForQuestionAnswering,
|
XLNetForQuestionAnswering,
|
||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
DistilBertConfig,
|
||||||
|
DistilBertForQuestionAnswering,
|
||||||
|
DistilBertTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||||
|
|
||||||
from utils_squad import (read_squad_examples, convert_examples_to_features,
|
from utils_squad import (
|
||||||
RawResult, write_predictions,
|
read_squad_examples,
|
||||||
RawResultExtended, write_predictions_extended)
|
convert_examples_to_features,
|
||||||
|
RawResult,
|
||||||
|
write_predictions,
|
||||||
|
RawResultExtended,
|
||||||
|
write_predictions_extended,
|
||||||
|
)
|
||||||
|
|
||||||
# The follwing import is the official SQuAD evaluation script (2.0).
|
# The follwing import is the official SQuAD evaluation script (2.0).
|
||||||
# You can remove it from the dependencies if you are using this script outside of the library
|
# You can remove it from the dependencies if you are using this script outside of the library
|
||||||
@@ -56,16 +68,18 @@ from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
ALL_MODELS = sum(
|
||||||
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
|
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||||
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||||
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||||
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def set_seed(args):
|
def set_seed(args):
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
@@ -73,9 +87,11 @@ def set_seed(args):
|
|||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
|
||||||
def to_list(tensor):
|
def to_list(tensor):
|
||||||
return tensor.detach().cpu().tolist()
|
return tensor.detach().cpu().tolist()
|
||||||
|
|
||||||
|
|
||||||
def train(args, train_dataset, model, tokenizer):
|
def train(args, train_dataset, model, tokenizer):
|
||||||
""" Train the model """
|
""" Train the model """
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
@@ -92,13 +108,18 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
|
|
||||||
# Prepare optimizer and schedule (linear warmup and decay)
|
# Prepare optimizer and schedule (linear warmup and decay)
|
||||||
no_decay = ['bias', 'LayerNorm.weight']
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
{
|
||||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
},
|
||||||
|
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
||||||
]
|
]
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||||
|
)
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
@@ -112,17 +133,21 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
output_device=args.local_rank,
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
find_unused_parameters=True)
|
)
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_dataset))
|
logger.info(" Num examples = %d", len(train_dataset))
|
||||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
logger.info(
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
|
args.train_batch_size
|
||||||
|
* args.gradient_accumulation_steps
|
||||||
|
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||||
|
)
|
||||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
@@ -136,15 +161,16 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
for step, batch in enumerate(epoch_iterator):
|
for step, batch in enumerate(epoch_iterator):
|
||||||
model.train()
|
model.train()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {
|
||||||
'attention_mask': batch[1],
|
"input_ids": batch[0],
|
||||||
'start_positions': batch[3],
|
"attention_mask": batch[1],
|
||||||
'end_positions': batch[4]}
|
"start_positions": batch[3],
|
||||||
if args.model_type != 'distilbert':
|
"end_positions": batch[4],
|
||||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
|
}
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type != "distilbert":
|
||||||
inputs.update({'cls_index': batch[5],
|
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
|
||||||
'p_mask': batch[6]})
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
|
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||||
|
|
||||||
@@ -173,22 +199,26 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
# Log metrics
|
# Log metrics
|
||||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
if (
|
||||||
|
args.local_rank == -1 and args.evaluate_during_training
|
||||||
|
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||||
results = evaluate(args, model, tokenizer)
|
results = evaluate(args, model, tokenizer)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||||
logging_loss = tr_loss
|
logging_loss = tr_loss
|
||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(output_dir)
|
model_to_save.save_pretrained(output_dir)
|
||||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
@@ -224,32 +254,31 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
model.eval()
|
model.eval()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
|
||||||
'attention_mask': batch[1]
|
if args.model_type != "distilbert":
|
||||||
}
|
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2] # XLM don't use segment_ids
|
||||||
if args.model_type != 'distilbert':
|
|
||||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
|
||||||
example_indices = batch[3]
|
example_indices = batch[3]
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
inputs.update({'cls_index': batch[4],
|
inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
|
||||||
'p_mask': batch[5]})
|
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
for i, example_index in enumerate(example_indices):
|
for i, example_index in enumerate(example_indices):
|
||||||
eval_feature = features[example_index.item()]
|
eval_feature = features[example_index.item()]
|
||||||
unique_id = int(eval_feature.unique_id)
|
unique_id = int(eval_feature.unique_id)
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
# XLNet uses a more complex post-processing procedure
|
# XLNet uses a more complex post-processing procedure
|
||||||
result = RawResultExtended(unique_id = unique_id,
|
result = RawResultExtended(
|
||||||
|
unique_id=unique_id,
|
||||||
start_top_log_probs=to_list(outputs[0][i]),
|
start_top_log_probs=to_list(outputs[0][i]),
|
||||||
start_top_index=to_list(outputs[1][i]),
|
start_top_index=to_list(outputs[1][i]),
|
||||||
end_top_log_probs=to_list(outputs[2][i]),
|
end_top_log_probs=to_list(outputs[2][i]),
|
||||||
end_top_index=to_list(outputs[3][i]),
|
end_top_index=to_list(outputs[3][i]),
|
||||||
cls_logits = to_list(outputs[4][i]))
|
cls_logits=to_list(outputs[4][i]),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
result = RawResult(unique_id = unique_id,
|
result = RawResult(
|
||||||
start_logits = to_list(outputs[0][i]),
|
unique_id=unique_id, start_logits=to_list(outputs[0][i]), end_logits=to_list(outputs[1][i])
|
||||||
end_logits = to_list(outputs[1][i]))
|
)
|
||||||
all_results.append(result)
|
all_results.append(result)
|
||||||
|
|
||||||
# Compute predictions
|
# Compute predictions
|
||||||
@@ -260,23 +289,44 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
else:
|
else:
|
||||||
output_null_log_odds_file = None
|
output_null_log_odds_file = None
|
||||||
|
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
# XLNet uses a more complex post-processing procedure
|
# XLNet uses a more complex post-processing procedure
|
||||||
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
write_predictions_extended(
|
||||||
args.max_answer_length, output_prediction_file,
|
examples,
|
||||||
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
features,
|
||||||
model.config.start_n_top, model.config.end_n_top,
|
all_results,
|
||||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
args.n_best_size,
|
||||||
|
args.max_answer_length,
|
||||||
|
output_prediction_file,
|
||||||
|
output_nbest_file,
|
||||||
|
output_null_log_odds_file,
|
||||||
|
args.predict_file,
|
||||||
|
model.config.start_n_top,
|
||||||
|
model.config.end_n_top,
|
||||||
|
args.version_2_with_negative,
|
||||||
|
tokenizer,
|
||||||
|
args.verbose_logging,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
write_predictions(examples, features, all_results, args.n_best_size,
|
write_predictions(
|
||||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
examples,
|
||||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
features,
|
||||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
all_results,
|
||||||
|
args.n_best_size,
|
||||||
|
args.max_answer_length,
|
||||||
|
args.do_lower_case,
|
||||||
|
output_prediction_file,
|
||||||
|
output_nbest_file,
|
||||||
|
output_null_log_odds_file,
|
||||||
|
args.verbose_logging,
|
||||||
|
args.version_2_with_negative,
|
||||||
|
args.null_score_diff_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
# Evaluate with the official SQuAD script
|
# Evaluate with the official SQuAD script
|
||||||
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
|
evaluate_options = EVAL_OPTS(
|
||||||
pred_file=output_prediction_file,
|
data_file=args.predict_file, pred_file=output_prediction_file, na_prob_file=output_null_log_odds_file
|
||||||
na_prob_file=output_null_log_odds_file)
|
)
|
||||||
results = evaluate_on_squad(evaluate_options)
|
results = evaluate_on_squad(evaluate_options)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -287,24 +337,30 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
|
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
input_file = args.predict_file if evaluate else args.train_file
|
input_file = args.predict_file if evaluate else args.train_file
|
||||||
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
|
cached_features_file = os.path.join(
|
||||||
'dev' if evaluate else 'train',
|
os.path.dirname(input_file),
|
||||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
"cached_{}_{}_{}".format(
|
||||||
str(args.max_seq_length)))
|
"dev" if evaluate else "train",
|
||||||
|
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||||
|
str(args.max_seq_length),
|
||||||
|
),
|
||||||
|
)
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
features = torch.load(cached_features_file)
|
features = torch.load(cached_features_file)
|
||||||
else:
|
else:
|
||||||
logger.info("Creating features from dataset file at %s", input_file)
|
logger.info("Creating features from dataset file at %s", input_file)
|
||||||
examples = read_squad_examples(input_file=input_file,
|
examples = read_squad_examples(
|
||||||
is_training=not evaluate,
|
input_file=input_file, is_training=not evaluate, version_2_with_negative=args.version_2_with_negative
|
||||||
version_2_with_negative=args.version_2_with_negative)
|
)
|
||||||
features = convert_examples_to_features(examples=examples,
|
features = convert_examples_to_features(
|
||||||
|
examples=examples,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
doc_stride=args.doc_stride,
|
doc_stride=args.doc_stride,
|
||||||
max_query_length=args.max_query_length,
|
max_query_length=args.max_query_length,
|
||||||
is_training=not evaluate)
|
is_training=not evaluate,
|
||||||
|
)
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
torch.save(features, cached_features_file)
|
torch.save(features, cached_features_file)
|
||||||
@@ -320,14 +376,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||||
if evaluate:
|
if evaluate:
|
||||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
dataset = TensorDataset(
|
||||||
all_example_index, all_cls_index, all_p_mask)
|
all_input_ids, all_input_mask, all_segment_ids, all_example_index, all_cls_index, all_p_mask
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||||
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
||||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
dataset = TensorDataset(
|
||||||
all_start_positions, all_end_positions,
|
all_input_ids,
|
||||||
all_cls_index, all_p_mask)
|
all_input_mask,
|
||||||
|
all_segment_ids,
|
||||||
|
all_start_positions,
|
||||||
|
all_end_positions,
|
||||||
|
all_cls_index,
|
||||||
|
all_p_mask,
|
||||||
|
)
|
||||||
|
|
||||||
if output_examples:
|
if output_examples:
|
||||||
return dataset, examples, features
|
return dataset, examples, features
|
||||||
@@ -338,109 +401,190 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--train_file", default=None, type=str, required=True,
|
parser.add_argument(
|
||||||
help="SQuAD json for training. E.g., train-v1.1.json")
|
"--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
|
||||||
parser.add_argument("--predict_file", default=None, type=str, required=True,
|
)
|
||||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
parser.add_argument(
|
||||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
"--predict_file",
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
default=None,
|
||||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
type=str,
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
required=True,
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json",
|
||||||
help="The output directory where the model checkpoints and predictions will be written.")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_type",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The output directory where the model checkpoints and predictions will be written.",
|
||||||
|
)
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument(
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
)
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
parser.add_argument(
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
"--tokenizer_name",
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_dir",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('--version_2_with_negative', action='store_true',
|
parser.add_argument(
|
||||||
help='If true, the SQuAD examples contain some that do not have an answer.')
|
"--version_2_with_negative",
|
||||||
parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
|
action="store_true",
|
||||||
help="If null_score - best_non_null is greater than the threshold predict null.")
|
help="If true, the SQuAD examples contain some that do not have an answer.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--null_score_diff_threshold",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="If null_score - best_non_null is greater than the threshold predict null.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--max_seq_length", default=384, type=int,
|
parser.add_argument(
|
||||||
|
"--max_seq_length",
|
||||||
|
default=384,
|
||||||
|
type=int,
|
||||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||||
"longer than this will be truncated, and sequences shorter than this will be padded.")
|
"longer than this will be truncated, and sequences shorter than this will be padded.",
|
||||||
parser.add_argument("--doc_stride", default=128, type=int,
|
)
|
||||||
help="When splitting up a long document into chunks, how much stride to take between chunks.")
|
parser.add_argument(
|
||||||
parser.add_argument("--max_query_length", default=64, type=int,
|
"--doc_stride",
|
||||||
|
default=128,
|
||||||
|
type=int,
|
||||||
|
help="When splitting up a long document into chunks, how much stride to take between chunks.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_query_length",
|
||||||
|
default=64,
|
||||||
|
type=int,
|
||||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||||
"be truncated to this length.")
|
"be truncated to this length.",
|
||||||
parser.add_argument("--do_train", action='store_true',
|
)
|
||||||
help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action='store_true',
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||||
help="Whether to run eval on the dev set.")
|
parser.add_argument(
|
||||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||||
help="Rul evaluation during training at each logging step.")
|
)
|
||||||
parser.add_argument("--do_lower_case", action='store_true',
|
parser.add_argument(
|
||||||
help="Set this flag if you are using an uncased model.")
|
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||||
help="Batch size per GPU/CPU for training.")
|
parser.add_argument(
|
||||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||||
help="Batch size per GPU/CPU for evaluation.")
|
)
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
help="The initial learning rate for Adam.")
|
parser.add_argument(
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
"--gradient_accumulation_steps",
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
type=int,
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
default=1,
|
||||||
help="Weight deay if we apply some.")
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
)
|
||||||
help="Epsilon for Adam optimizer.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
help="Max gradient norm.")
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument(
|
||||||
help="Total number of training epochs to perform.")
|
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
)
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
parser.add_argument(
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
"--max_steps",
|
||||||
help="Linear warmup over warmup_steps.")
|
default=-1,
|
||||||
parser.add_argument("--n_best_size", default=20, type=int,
|
type=int,
|
||||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||||
parser.add_argument("--max_answer_length", default=30, type=int,
|
)
|
||||||
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_best_size",
|
||||||
|
default=20,
|
||||||
|
type=int,
|
||||||
|
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_answer_length",
|
||||||
|
default=30,
|
||||||
|
type=int,
|
||||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
"and end predictions are not conditioned on one another.")
|
"and end predictions are not conditioned on one another.",
|
||||||
parser.add_argument("--verbose_logging", action='store_true',
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose_logging",
|
||||||
|
action="store_true",
|
||||||
help="If true, all of the warnings related to data processing will be printed. "
|
help="If true, all of the warnings related to data processing will be printed. "
|
||||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
"A number of warnings are expected for a normal SQuAD evaluation.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('--logging_steps', type=int, default=50,
|
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||||
help="Log every X updates steps.")
|
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
||||||
parser.add_argument('--save_steps', type=int, default=50,
|
parser.add_argument(
|
||||||
help="Save checkpoint every X updates steps.")
|
"--eval_all_checkpoints",
|
||||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
action="store_true",
|
||||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
)
|
||||||
help="Whether not to use CUDA when available")
|
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
parser.add_argument(
|
||||||
help="Overwrite the content of the output directory")
|
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||||
parser.add_argument('--overwrite_cache', action='store_true',
|
)
|
||||||
help="Overwrite the cached training and evaluation sets")
|
parser.add_argument(
|
||||||
parser.add_argument('--seed', type=int, default=42,
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||||
help="random seed for initialization")
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
|
||||||
parser.add_argument("--local_rank", type=int, default=-1,
|
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||||
help="local_rank for distributed training on gpus")
|
parser.add_argument(
|
||||||
parser.add_argument('--fp16', action='store_true',
|
"--fp16",
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
action="store_true",
|
||||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16_opt_level",
|
||||||
|
type=str,
|
||||||
|
default="O1",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html")
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
)
|
||||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||||
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
if (
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
os.path.exists(args.output_dir)
|
||||||
|
and os.listdir(args.output_dir)
|
||||||
|
and args.do_train
|
||||||
|
and not args.overwrite_output_dir
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||||
|
args.output_dir
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Setup distant debugging if needed
|
# Setup distant debugging if needed
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
import ptvsd
|
import ptvsd
|
||||||
|
|
||||||
print("Waiting for debugger attach")
|
print("Waiting for debugger attach")
|
||||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||||
ptvsd.wait_for_attach()
|
ptvsd.wait_for_attach()
|
||||||
@@ -452,16 +596,24 @@ def main():
|
|||||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||||
torch.cuda.set_device(args.local_rank)
|
torch.cuda.set_device(args.local_rank)
|
||||||
device = torch.device("cuda", args.local_rank)
|
device = torch.device("cuda", args.local_rank)
|
||||||
torch.distributed.init_process_group(backend='nccl')
|
torch.distributed.init_process_group(backend="nccl")
|
||||||
args.n_gpu = 1
|
args.n_gpu = 1
|
||||||
args.device = device
|
args.device = device
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
||||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||||
|
args.local_rank,
|
||||||
|
device,
|
||||||
|
args.n_gpu,
|
||||||
|
bool(args.local_rank != -1),
|
||||||
|
args.fp16,
|
||||||
|
)
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
@@ -472,15 +624,21 @@ def main():
|
|||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
config = config_class.from_pretrained(
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
args.config_name if args.config_name else args.model_name_or_path,
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
do_lower_case=args.do_lower_case,
|
do_lower_case=args.do_lower_case,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
model = model_class.from_pretrained(args.model_name_or_path,
|
)
|
||||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
model = model_class.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||||
|
)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
@@ -495,7 +653,8 @@ def main():
|
|||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
import apex
|
import apex
|
||||||
apex.amp.register_half_function(torch, 'einsum')
|
|
||||||
|
apex.amp.register_half_function(torch, "einsum")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||||
|
|
||||||
@@ -505,7 +664,6 @@ def main():
|
|||||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
|
|
||||||
# Save the trained model and the tokenizer
|
# Save the trained model and the tokenizer
|
||||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
@@ -515,39 +673,42 @@ def main():
|
|||||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = (
|
||||||
|
model.module if hasattr(model, "module") else model
|
||||||
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(args.output_dir)
|
model_to_save.save_pretrained(args.output_dir)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = model_class.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
|
|
||||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
checkpoints = list(
|
||||||
|
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
||||||
|
)
|
||||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||||
|
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
|
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
# Reload the model
|
# Reload the model
|
||||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = model_class.from_pretrained(checkpoint)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||||
|
|
||||||
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
|
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
|
||||||
logger.info("Results: {}".format(results))
|
logger.info("Results: {}".format(results))
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 XXX. All rights reserved.
|
# Copyright 2018 XXX. All rights reserved.
|
||||||
#
|
#
|
||||||
@@ -37,14 +36,16 @@ class SquadExample(object):
|
|||||||
For examples without an answer, the start and end position are -1.
|
For examples without an answer, the start and end position are -1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
qas_id,
|
qas_id,
|
||||||
question_text,
|
question_text,
|
||||||
doc_tokens,
|
doc_tokens,
|
||||||
orig_answer_text=None,
|
orig_answer_text=None,
|
||||||
start_position=None,
|
start_position=None,
|
||||||
end_position=None,
|
end_position=None,
|
||||||
is_impossible=None):
|
is_impossible=None,
|
||||||
|
):
|
||||||
self.qas_id = qas_id
|
self.qas_id = qas_id
|
||||||
self.question_text = question_text
|
self.question_text = question_text
|
||||||
self.doc_tokens = doc_tokens
|
self.doc_tokens = doc_tokens
|
||||||
@@ -59,8 +60,7 @@ class SquadExample(object):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
s = ""
|
s = ""
|
||||||
s += "qas_id: %s" % (self.qas_id)
|
s += "qas_id: %s" % (self.qas_id)
|
||||||
s += ", question_text: %s" % (
|
s += ", question_text: %s" % (self.question_text)
|
||||||
self.question_text)
|
|
||||||
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
|
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
|
||||||
if self.start_position:
|
if self.start_position:
|
||||||
s += ", start_position: %d" % (self.start_position)
|
s += ", start_position: %d" % (self.start_position)
|
||||||
@@ -74,7 +74,8 @@ class SquadExample(object):
|
|||||||
class InputFeatures(object):
|
class InputFeatures(object):
|
||||||
"""A single set of features of data."""
|
"""A single set of features of data."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
unique_id,
|
unique_id,
|
||||||
example_index,
|
example_index,
|
||||||
doc_span_index,
|
doc_span_index,
|
||||||
@@ -89,7 +90,8 @@ class InputFeatures(object):
|
|||||||
paragraph_len,
|
paragraph_len,
|
||||||
start_position=None,
|
start_position=None,
|
||||||
end_position=None,
|
end_position=None,
|
||||||
is_impossible=None):
|
is_impossible=None,
|
||||||
|
):
|
||||||
self.unique_id = unique_id
|
self.unique_id = unique_id
|
||||||
self.example_index = example_index
|
self.example_index = example_index
|
||||||
self.doc_span_index = doc_span_index
|
self.doc_span_index = doc_span_index
|
||||||
@@ -109,7 +111,7 @@ class InputFeatures(object):
|
|||||||
|
|
||||||
def read_squad_examples(input_file, is_training, version_2_with_negative):
|
def read_squad_examples(input_file, is_training, version_2_with_negative):
|
||||||
"""Read a SQuAD json file into a list of SquadExample."""
|
"""Read a SQuAD json file into a list of SquadExample."""
|
||||||
with open(input_file, "r", encoding='utf-8') as reader:
|
with open(input_file, "r", encoding="utf-8") as reader:
|
||||||
input_data = json.load(reader)["data"]
|
input_data = json.load(reader)["data"]
|
||||||
|
|
||||||
def is_whitespace(c):
|
def is_whitespace(c):
|
||||||
@@ -146,8 +148,7 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
|
|||||||
if version_2_with_negative:
|
if version_2_with_negative:
|
||||||
is_impossible = qa["is_impossible"]
|
is_impossible = qa["is_impossible"]
|
||||||
if (len(qa["answers"]) != 1) and (not is_impossible):
|
if (len(qa["answers"]) != 1) and (not is_impossible):
|
||||||
raise ValueError(
|
raise ValueError("For training, each question should have exactly 1 answer.")
|
||||||
"For training, each question should have exactly 1 answer.")
|
|
||||||
if not is_impossible:
|
if not is_impossible:
|
||||||
answer = qa["answers"][0]
|
answer = qa["answers"][0]
|
||||||
orig_answer_text = answer["text"]
|
orig_answer_text = answer["text"]
|
||||||
@@ -162,11 +163,9 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
|
|||||||
# Note that this means for training mode, every example is NOT
|
# Note that this means for training mode, every example is NOT
|
||||||
# guaranteed to be preserved.
|
# guaranteed to be preserved.
|
||||||
actual_text = " ".join(doc_tokens[start_position : (end_position + 1)])
|
actual_text = " ".join(doc_tokens[start_position : (end_position + 1)])
|
||||||
cleaned_answer_text = " ".join(
|
cleaned_answer_text = " ".join(whitespace_tokenize(orig_answer_text))
|
||||||
whitespace_tokenize(orig_answer_text))
|
|
||||||
if actual_text.find(cleaned_answer_text) == -1:
|
if actual_text.find(cleaned_answer_text) == -1:
|
||||||
logger.warning("Could not find answer: '%s' vs. '%s'",
|
logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
|
||||||
actual_text, cleaned_answer_text)
|
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
start_position = -1
|
start_position = -1
|
||||||
@@ -180,18 +179,29 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
|
|||||||
orig_answer_text=orig_answer_text,
|
orig_answer_text=orig_answer_text,
|
||||||
start_position=start_position,
|
start_position=start_position,
|
||||||
end_position=end_position,
|
end_position=end_position,
|
||||||
is_impossible=is_impossible)
|
is_impossible=is_impossible,
|
||||||
|
)
|
||||||
examples.append(example)
|
examples.append(example)
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
def convert_examples_to_features(
|
||||||
doc_stride, max_query_length, is_training,
|
examples,
|
||||||
|
tokenizer,
|
||||||
|
max_seq_length,
|
||||||
|
doc_stride,
|
||||||
|
max_query_length,
|
||||||
|
is_training,
|
||||||
cls_token_at_end=False,
|
cls_token_at_end=False,
|
||||||
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
cls_token="[CLS]",
|
||||||
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
sep_token="[SEP]",
|
||||||
cls_token_segment_id=0, pad_token_segment_id=0,
|
pad_token=0,
|
||||||
mask_padding_with_zero=True):
|
sequence_a_segment_id=0,
|
||||||
|
sequence_b_segment_id=1,
|
||||||
|
cls_token_segment_id=0,
|
||||||
|
pad_token_segment_id=0,
|
||||||
|
mask_padding_with_zero=True,
|
||||||
|
):
|
||||||
"""Loads a data file into a list of `InputBatch`s."""
|
"""Loads a data file into a list of `InputBatch`s."""
|
||||||
|
|
||||||
unique_id = 1000000000
|
unique_id = 1000000000
|
||||||
@@ -232,8 +242,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
else:
|
else:
|
||||||
tok_end_position = len(all_doc_tokens) - 1
|
tok_end_position = len(all_doc_tokens) - 1
|
||||||
(tok_start_position, tok_end_position) = _improve_answer_span(
|
(tok_start_position, tok_end_position) = _improve_answer_span(
|
||||||
all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
|
all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.orig_answer_text
|
||||||
example.orig_answer_text)
|
)
|
||||||
|
|
||||||
# The -3 accounts for [CLS], [SEP] and [SEP]
|
# The -3 accounts for [CLS], [SEP] and [SEP]
|
||||||
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
|
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
|
||||||
@@ -241,8 +251,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
# We can have documents that are longer than the maximum sequence length.
|
# We can have documents that are longer than the maximum sequence length.
|
||||||
# To deal with this we do a sliding window approach, where we take chunks
|
# To deal with this we do a sliding window approach, where we take chunks
|
||||||
# of the up to our max length with a stride of `doc_stride`.
|
# of the up to our max length with a stride of `doc_stride`.
|
||||||
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
|
_DocSpan = collections.namedtuple("DocSpan", ["start", "length"]) # pylint: disable=invalid-name
|
||||||
"DocSpan", ["start", "length"])
|
|
||||||
doc_spans = []
|
doc_spans = []
|
||||||
start_offset = 0
|
start_offset = 0
|
||||||
while start_offset < len(all_doc_tokens):
|
while start_offset < len(all_doc_tokens):
|
||||||
@@ -287,8 +296,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
split_token_index = doc_span.start + i
|
split_token_index = doc_span.start + i
|
||||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||||
|
|
||||||
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
|
is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index)
|
||||||
split_token_index)
|
|
||||||
token_is_max_context[len(tokens)] = is_max_context
|
token_is_max_context[len(tokens)] = is_max_context
|
||||||
tokens.append(all_doc_tokens[split_token_index])
|
tokens.append(all_doc_tokens[split_token_index])
|
||||||
segment_ids.append(sequence_b_segment_id)
|
segment_ids.append(sequence_b_segment_id)
|
||||||
@@ -333,8 +341,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
doc_start = doc_span.start
|
doc_start = doc_span.start
|
||||||
doc_end = doc_span.start + doc_span.length - 1
|
doc_end = doc_span.start + doc_span.length - 1
|
||||||
out_of_span = False
|
out_of_span = False
|
||||||
if not (tok_start_position >= doc_start and
|
if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
|
||||||
tok_end_position <= doc_end):
|
|
||||||
out_of_span = True
|
out_of_span = True
|
||||||
if out_of_span:
|
if out_of_span:
|
||||||
start_position = 0
|
start_position = 0
|
||||||
@@ -355,24 +362,23 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
logger.info("example_index: %s" % (example_index))
|
logger.info("example_index: %s" % (example_index))
|
||||||
logger.info("doc_span_index: %s" % (doc_span_index))
|
logger.info("doc_span_index: %s" % (doc_span_index))
|
||||||
logger.info("tokens: %s" % " ".join(tokens))
|
logger.info("tokens: %s" % " ".join(tokens))
|
||||||
logger.info("token_to_orig_map: %s" % " ".join([
|
logger.info(
|
||||||
"%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()]))
|
"token_to_orig_map: %s" % " ".join(["%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()])
|
||||||
logger.info("token_is_max_context: %s" % " ".join([
|
)
|
||||||
"%d:%s" % (x, y) for (x, y) in token_is_max_context.items()
|
logger.info(
|
||||||
]))
|
"token_is_max_context: %s"
|
||||||
|
% " ".join(["%d:%s" % (x, y) for (x, y) in token_is_max_context.items()])
|
||||||
|
)
|
||||||
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
||||||
logger.info(
|
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||||
"input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||||
logger.info(
|
|
||||||
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
|
||||||
if is_training and span_is_impossible:
|
if is_training and span_is_impossible:
|
||||||
logger.info("impossible example")
|
logger.info("impossible example")
|
||||||
if is_training and not span_is_impossible:
|
if is_training and not span_is_impossible:
|
||||||
answer_text = " ".join(tokens[start_position : (end_position + 1)])
|
answer_text = " ".join(tokens[start_position : (end_position + 1)])
|
||||||
logger.info("start_position: %d" % (start_position))
|
logger.info("start_position: %d" % (start_position))
|
||||||
logger.info("end_position: %d" % (end_position))
|
logger.info("end_position: %d" % (end_position))
|
||||||
logger.info(
|
logger.info("answer: %s" % (answer_text))
|
||||||
"answer: %s" % (answer_text))
|
|
||||||
|
|
||||||
features.append(
|
features.append(
|
||||||
InputFeatures(
|
InputFeatures(
|
||||||
@@ -390,14 +396,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
paragraph_len=paragraph_len,
|
paragraph_len=paragraph_len,
|
||||||
start_position=start_position,
|
start_position=start_position,
|
||||||
end_position=end_position,
|
end_position=end_position,
|
||||||
is_impossible=span_is_impossible))
|
is_impossible=span_is_impossible,
|
||||||
|
)
|
||||||
|
)
|
||||||
unique_id += 1
|
unique_id += 1
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
|
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
|
||||||
orig_answer_text):
|
|
||||||
"""Returns tokenized answer spans that better match the annotated answer."""
|
"""Returns tokenized answer spans that better match the annotated answer."""
|
||||||
|
|
||||||
# The SQuAD annotations are character based. We first project them to
|
# The SQuAD annotations are character based. We first project them to
|
||||||
@@ -470,13 +477,23 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
|
|||||||
return cur_span_index == best_span_index
|
return cur_span_index == best_span_index
|
||||||
|
|
||||||
|
|
||||||
RawResult = collections.namedtuple("RawResult",
|
RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])
|
||||||
["unique_id", "start_logits", "end_logits"])
|
|
||||||
|
|
||||||
def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|
||||||
max_answer_length, do_lower_case, output_prediction_file,
|
def write_predictions(
|
||||||
output_nbest_file, output_null_log_odds_file, verbose_logging,
|
all_examples,
|
||||||
version_2_with_negative, null_score_diff_threshold):
|
all_features,
|
||||||
|
all_results,
|
||||||
|
n_best_size,
|
||||||
|
max_answer_length,
|
||||||
|
do_lower_case,
|
||||||
|
output_prediction_file,
|
||||||
|
output_nbest_file,
|
||||||
|
output_null_log_odds_file,
|
||||||
|
verbose_logging,
|
||||||
|
version_2_with_negative,
|
||||||
|
null_score_diff_threshold,
|
||||||
|
):
|
||||||
"""Write final predictions to the json file and log-odds of null if needed."""
|
"""Write final predictions to the json file and log-odds of null if needed."""
|
||||||
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
||||||
logger.info("Writing nbest to: %s" % (output_nbest_file))
|
logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||||
@@ -490,8 +507,8 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
unique_id_to_result[result.unique_id] = result
|
unique_id_to_result[result.unique_id] = result
|
||||||
|
|
||||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
"PrelimPrediction",
|
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
|
||||||
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
|
)
|
||||||
|
|
||||||
all_predictions = collections.OrderedDict()
|
all_predictions = collections.OrderedDict()
|
||||||
all_nbest_json = collections.OrderedDict()
|
all_nbest_json = collections.OrderedDict()
|
||||||
@@ -544,7 +561,9 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
start_index=start_index,
|
start_index=start_index,
|
||||||
end_index=end_index,
|
end_index=end_index,
|
||||||
start_logit=result.start_logits[start_index],
|
start_logit=result.start_logits[start_index],
|
||||||
end_logit=result.end_logits[end_index]))
|
end_logit=result.end_logits[end_index],
|
||||||
|
)
|
||||||
|
)
|
||||||
if version_2_with_negative:
|
if version_2_with_negative:
|
||||||
prelim_predictions.append(
|
prelim_predictions.append(
|
||||||
_PrelimPrediction(
|
_PrelimPrediction(
|
||||||
@@ -552,14 +571,14 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
start_index=0,
|
start_index=0,
|
||||||
end_index=0,
|
end_index=0,
|
||||||
start_logit=null_start_logit,
|
start_logit=null_start_logit,
|
||||||
end_logit=null_end_logit))
|
end_logit=null_end_logit,
|
||||||
prelim_predictions = sorted(
|
)
|
||||||
prelim_predictions,
|
)
|
||||||
key=lambda x: (x.start_logit + x.end_logit),
|
prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
|
||||||
reverse=True)
|
|
||||||
|
|
||||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
"NbestPrediction", ["text", "start_logit", "end_logit"])
|
"NbestPrediction", ["text", "start_logit", "end_logit"]
|
||||||
|
)
|
||||||
|
|
||||||
seen_predictions = {}
|
seen_predictions = {}
|
||||||
nbest = []
|
nbest = []
|
||||||
@@ -592,31 +611,21 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
final_text = ""
|
final_text = ""
|
||||||
seen_predictions[final_text] = True
|
seen_predictions[final_text] = True
|
||||||
|
|
||||||
nbest.append(
|
nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
|
||||||
_NbestPrediction(
|
|
||||||
text=final_text,
|
|
||||||
start_logit=pred.start_logit,
|
|
||||||
end_logit=pred.end_logit))
|
|
||||||
# if we didn't include the empty option in the n-best, include it
|
# if we didn't include the empty option in the n-best, include it
|
||||||
if version_2_with_negative:
|
if version_2_with_negative:
|
||||||
if "" not in seen_predictions:
|
if "" not in seen_predictions:
|
||||||
nbest.append(
|
nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
|
||||||
_NbestPrediction(
|
|
||||||
text="",
|
|
||||||
start_logit=null_start_logit,
|
|
||||||
end_logit=null_end_logit))
|
|
||||||
|
|
||||||
# In very rare edge cases we could only have single null prediction.
|
# In very rare edge cases we could only have single null prediction.
|
||||||
# So we just create a nonce prediction in this case to avoid failure.
|
# So we just create a nonce prediction in this case to avoid failure.
|
||||||
if len(nbest) == 1:
|
if len(nbest) == 1:
|
||||||
nbest.insert(0,
|
nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
|
||||||
|
|
||||||
# In very rare edge cases we could have no valid predictions. So we
|
# In very rare edge cases we could have no valid predictions. So we
|
||||||
# just create a nonce prediction in this case to avoid failure.
|
# just create a nonce prediction in this case to avoid failure.
|
||||||
if not nbest:
|
if not nbest:
|
||||||
nbest.append(
|
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
|
||||||
|
|
||||||
assert len(nbest) >= 1
|
assert len(nbest) >= 1
|
||||||
|
|
||||||
@@ -645,8 +654,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
||||||
else:
|
else:
|
||||||
# predict "" iff the null score - the score of best non-null > threshold
|
# predict "" iff the null score - the score of best non-null > threshold
|
||||||
score_diff = score_null - best_non_null_entry.start_logit - (
|
score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
|
||||||
best_non_null_entry.end_logit)
|
|
||||||
scores_diff_json[example.qas_id] = score_diff
|
scores_diff_json[example.qas_id] = score_diff
|
||||||
if score_diff > null_score_diff_threshold:
|
if score_diff > null_score_diff_threshold:
|
||||||
all_predictions[example.qas_id] = ""
|
all_predictions[example.qas_id] = ""
|
||||||
@@ -668,29 +676,40 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
|
|
||||||
|
|
||||||
# For XLNet (and XLM which uses the same head)
|
# For XLNet (and XLM which uses the same head)
|
||||||
RawResultExtended = collections.namedtuple("RawResultExtended",
|
RawResultExtended = collections.namedtuple(
|
||||||
["unique_id", "start_top_log_probs", "start_top_index",
|
"RawResultExtended",
|
||||||
"end_top_log_probs", "end_top_index", "cls_logits"])
|
["unique_id", "start_top_log_probs", "start_top_index", "end_top_log_probs", "end_top_index", "cls_logits"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
|
def write_predictions_extended(
|
||||||
max_answer_length, output_prediction_file,
|
all_examples,
|
||||||
|
all_features,
|
||||||
|
all_results,
|
||||||
|
n_best_size,
|
||||||
|
max_answer_length,
|
||||||
|
output_prediction_file,
|
||||||
output_nbest_file,
|
output_nbest_file,
|
||||||
output_null_log_odds_file, orig_data_file,
|
output_null_log_odds_file,
|
||||||
start_n_top, end_n_top, version_2_with_negative,
|
orig_data_file,
|
||||||
tokenizer, verbose_logging):
|
start_n_top,
|
||||||
|
end_n_top,
|
||||||
|
version_2_with_negative,
|
||||||
|
tokenizer,
|
||||||
|
verbose_logging,
|
||||||
|
):
|
||||||
""" XLNet write prediction logic (more complex than Bert's).
|
""" XLNet write prediction logic (more complex than Bert's).
|
||||||
Write final predictions to the json file and log-odds of null if needed.
|
Write final predictions to the json file and log-odds of null if needed.
|
||||||
|
|
||||||
Requires utils_squad_evaluate.py
|
Requires utils_squad_evaluate.py
|
||||||
"""
|
"""
|
||||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
"PrelimPrediction",
|
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
|
||||||
["feature_index", "start_index", "end_index",
|
)
|
||||||
"start_log_prob", "end_log_prob"])
|
|
||||||
|
|
||||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
|
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Writing predictions to: %s", output_prediction_file)
|
logger.info("Writing predictions to: %s", output_prediction_file)
|
||||||
# logger.info("Writing nbest to: %s" % (output_nbest_file))
|
# logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||||
@@ -754,12 +773,13 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
|||||||
start_index=start_index,
|
start_index=start_index,
|
||||||
end_index=end_index,
|
end_index=end_index,
|
||||||
start_log_prob=start_log_prob,
|
start_log_prob=start_log_prob,
|
||||||
end_log_prob=end_log_prob))
|
end_log_prob=end_log_prob,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
prelim_predictions = sorted(
|
prelim_predictions = sorted(
|
||||||
prelim_predictions,
|
prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
|
||||||
key=lambda x: (x.start_log_prob + x.end_log_prob),
|
)
|
||||||
reverse=True)
|
|
||||||
|
|
||||||
seen_predictions = {}
|
seen_predictions = {}
|
||||||
nbest = []
|
nbest = []
|
||||||
@@ -790,8 +810,7 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
|||||||
tok_text = " ".join(tok_text.split())
|
tok_text = " ".join(tok_text.split())
|
||||||
orig_text = " ".join(orig_tokens)
|
orig_text = " ".join(orig_tokens)
|
||||||
|
|
||||||
final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
|
final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case, verbose_logging)
|
||||||
verbose_logging)
|
|
||||||
|
|
||||||
if final_text in seen_predictions:
|
if final_text in seen_predictions:
|
||||||
continue
|
continue
|
||||||
@@ -799,17 +818,13 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
|||||||
seen_predictions[final_text] = True
|
seen_predictions[final_text] = True
|
||||||
|
|
||||||
nbest.append(
|
nbest.append(
|
||||||
_NbestPrediction(
|
_NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
|
||||||
text=final_text,
|
)
|
||||||
start_log_prob=pred.start_log_prob,
|
|
||||||
end_log_prob=pred.end_log_prob))
|
|
||||||
|
|
||||||
# In very rare edge cases we could have no valid predictions. So we
|
# In very rare edge cases we could have no valid predictions. So we
|
||||||
# just create a nonce prediction in this case to avoid failure.
|
# just create a nonce prediction in this case to avoid failure.
|
||||||
if not nbest:
|
if not nbest:
|
||||||
nbest.append(
|
nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
|
||||||
_NbestPrediction(text="", start_log_prob=-1e6,
|
|
||||||
end_log_prob=-1e6))
|
|
||||||
|
|
||||||
total_scores = []
|
total_scores = []
|
||||||
best_non_null_entry = None
|
best_non_null_entry = None
|
||||||
@@ -850,7 +865,7 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
|
|||||||
with open(output_null_log_odds_file, "w") as writer:
|
with open(output_null_log_odds_file, "w") as writer:
|
||||||
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||||
|
|
||||||
with open(orig_data_file, "r", encoding='utf-8') as reader:
|
with open(orig_data_file, "r", encoding="utf-8") as reader:
|
||||||
orig_data = json.load(reader)["data"]
|
orig_data = json.load(reader)["data"]
|
||||||
|
|
||||||
qid_to_has_ans = make_qid_to_has_ans(orig_data)
|
qid_to_has_ans = make_qid_to_has_ans(orig_data)
|
||||||
@@ -914,8 +929,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
|||||||
start_position = tok_text.find(pred_text)
|
start_position = tok_text.find(pred_text)
|
||||||
if start_position == -1:
|
if start_position == -1:
|
||||||
if verbose_logging:
|
if verbose_logging:
|
||||||
logger.info(
|
logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
|
||||||
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
|
|
||||||
return orig_text
|
return orig_text
|
||||||
end_position = start_position + len(pred_text) - 1
|
end_position = start_position + len(pred_text) - 1
|
||||||
|
|
||||||
@@ -924,8 +938,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
|||||||
|
|
||||||
if len(orig_ns_text) != len(tok_ns_text):
|
if len(orig_ns_text) != len(tok_ns_text):
|
||||||
if verbose_logging:
|
if verbose_logging:
|
||||||
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
|
logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text)
|
||||||
orig_ns_text, tok_ns_text)
|
|
||||||
return orig_text
|
return orig_text
|
||||||
|
|
||||||
# We then project the characters in `pred_text` back to `orig_text` using
|
# We then project the characters in `pred_text` back to `orig_text` using
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ from .configuration_utils import PretrainedConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
XXX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XXX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-config.json",
|
"xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-config.json",
|
||||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-config.json",
|
"xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -63,7 +63,8 @@ class XxxConfig(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
vocab_size=50257,
|
vocab_size=50257,
|
||||||
n_positions=1024,
|
n_positions=1024,
|
||||||
n_ctx=1024,
|
n_ctx=1024,
|
||||||
@@ -75,12 +76,13 @@ class XxxConfig(PretrainedConfig):
|
|||||||
attn_pdrop=0.1,
|
attn_pdrop=0.1,
|
||||||
layer_norm_epsilon=1e-5,
|
layer_norm_epsilon=1e-5,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
summary_type='cls_index',
|
summary_type="cls_index",
|
||||||
summary_use_proj=True,
|
summary_use_proj=True,
|
||||||
summary_activation=None,
|
summary_activation=None,
|
||||||
summary_proj_to_labels=True,
|
summary_proj_to_labels=True,
|
||||||
summary_first_dropout=0.1,
|
summary_first_dropout=0.1,
|
||||||
**kwargs):
|
**kwargs
|
||||||
|
):
|
||||||
super(XxxConfig, self).__init__(**kwargs)
|
super(XxxConfig, self).__init__(**kwargs)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.n_ctx = n_ctx
|
self.n_ctx = n_ctx
|
||||||
|
|||||||
@@ -24,8 +24,10 @@ import torch
|
|||||||
from transformers import XxxConfig, XxxForPreTraining, load_tf_weights_in_xxx
|
from transformers import XxxConfig, XxxForPreTraining, load_tf_weights_in_xxx
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||||
# Initialise PyTorch model
|
# Initialise PyTorch model
|
||||||
config = XxxConfig.from_json_file(config_file)
|
config = XxxConfig.from_json_file(config_file)
|
||||||
@@ -43,23 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--tf_checkpoint_path",
|
parser.add_argument(
|
||||||
default = None,
|
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||||
type = str,
|
)
|
||||||
required = True,
|
parser.add_argument(
|
||||||
help = "Path to the TensorFlow checkpoint path.")
|
"--config_file",
|
||||||
parser.add_argument("--config_file",
|
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The config json file corresponding to the pre-trained model. \n"
|
help="The config json file corresponding to the pre-trained model. \n"
|
||||||
"This specifies the model architecture.")
|
"This specifies the model architecture.",
|
||||||
parser.add_argument("--pytorch_dump_path",
|
)
|
||||||
default = None,
|
parser.add_argument(
|
||||||
type = str,
|
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||||
required = True,
|
)
|
||||||
help = "Path to the output PyTorch model.")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
||||||
args.config_file,
|
|
||||||
args.pytorch_dump_path)
|
|
||||||
|
|||||||
@@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
|
|||||||
# for the pretrained weights provided with the models
|
# for the pretrained weights provided with the models
|
||||||
####################################################
|
####################################################
|
||||||
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-tf_model.h5",
|
"xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-tf_model.h5",
|
||||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-tf_model.h5",
|
"xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-tf_model.h5",
|
||||||
}
|
}
|
||||||
|
|
||||||
####################################################
|
####################################################
|
||||||
@@ -69,9 +69,9 @@ TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
class TFXxxLayer(tf.keras.layers.Layer):
|
class TFXxxLayer(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super(TFXxxLayer, self).__init__(**kwargs)
|
super(TFXxxLayer, self).__init__(**kwargs)
|
||||||
self.attention = TFXxxAttention(config, name='attention')
|
self.attention = TFXxxAttention(config, name="attention")
|
||||||
self.intermediate = TFXxxIntermediate(config, name='intermediate')
|
self.intermediate = TFXxxIntermediate(config, name="intermediate")
|
||||||
self.transformer_output = TFXxxOutput(config, name='output')
|
self.transformer_output = TFXxxOutput(config, name="output")
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
hidden_states, attention_mask, head_mask = inputs
|
hidden_states, attention_mask, head_mask = inputs
|
||||||
@@ -98,7 +98,9 @@ class TFXxxMainLayer(tf.keras.layers.Layer):
|
|||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models
|
raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models
|
||||||
|
|
||||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
def call(
|
||||||
|
self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False
|
||||||
|
):
|
||||||
# We allow three types of multi-inputs:
|
# We allow three types of multi-inputs:
|
||||||
# - traditional keyword arguments in the call method
|
# - traditional keyword arguments in the call method
|
||||||
# - all the arguments provided as a dict in the first positional argument of call
|
# - all the arguments provided as a dict in the first positional argument of call
|
||||||
@@ -113,11 +115,11 @@ class TFXxxMainLayer(tf.keras.layers.Layer):
|
|||||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
assert len(inputs) <= 5, "Too many inputs."
|
assert len(inputs) <= 5, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get("input_ids")
|
||||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||||
position_ids = inputs.get('position_ids', position_ids)
|
position_ids = inputs.get("position_ids", position_ids)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
assert len(inputs) <= 5, "Too many inputs."
|
assert len(inputs) <= 5, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
@@ -175,6 +177,7 @@ class TFXxxPreTrainedModel(TFPreTrainedModel):
|
|||||||
""" An abstract class to handle weights initialization and
|
""" An abstract class to handle weights initialization and
|
||||||
a simple interface for dowloading and loading pretrained models.
|
a simple interface for dowloading and loading pretrained models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = XxxConfig
|
config_class = XxxConfig
|
||||||
pretrained_model_archive_map = TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
@@ -263,8 +266,12 @@ XXX_INPUTS_DOCSTRING = r"""
|
|||||||
than the model's internal embedding lookup matrix.
|
than the model's internal embedding lookup matrix.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@add_start_docstrings("The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
|
|
||||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
@add_start_docstrings(
|
||||||
|
"The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
|
||||||
|
XXX_START_DOCSTRING,
|
||||||
|
XXX_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class TFXxxModel(TFXxxPreTrainedModel):
|
class TFXxxModel(TFXxxPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
@@ -297,17 +304,19 @@ class TFXxxModel(TFXxxPreTrainedModel):
|
|||||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super(TFXxxModel, self).__init__(config, *inputs, **kwargs)
|
super(TFXxxModel, self).__init__(config, *inputs, **kwargs)
|
||||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||||
|
|
||||||
def call(self, inputs, **kwargs):
|
def call(self, inputs, **kwargs):
|
||||||
outputs = self.transformer(inputs, **kwargs)
|
outputs = self.transformer(inputs, **kwargs)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Xxx Model with a `language modeling` head on top. """,
|
@add_start_docstrings(
|
||||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
"""Xxx Model with a `language modeling` head on top. """, XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING
|
||||||
|
)
|
||||||
class TFXxxForMaskedLM(TFXxxPreTrainedModel):
|
class TFXxxForMaskedLM(TFXxxPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
@@ -333,26 +342,30 @@ class TFXxxForMaskedLM(TFXxxPreTrainedModel):
|
|||||||
prediction_scores = outputs[0]
|
prediction_scores = outputs[0]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super(TFXxxForMaskedLM, self).__init__(config, *inputs, **kwargs)
|
super(TFXxxForMaskedLM, self).__init__(config, *inputs, **kwargs)
|
||||||
|
|
||||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||||
self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name='mlm')
|
self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name="mlm")
|
||||||
|
|
||||||
def call(self, inputs, **kwargs):
|
def call(self, inputs, **kwargs):
|
||||||
outputs = self.transformer(inputs, **kwargs)
|
outputs = self.transformer(inputs, **kwargs)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))
|
prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
|
||||||
|
|
||||||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
||||||
|
|
||||||
return outputs # prediction_scores, (hidden_states), (attentions)
|
return outputs # prediction_scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
@add_start_docstrings(
|
||||||
|
"""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||||
the pooled output) e.g. for GLUE tasks. """,
|
the pooled output) e.g. for GLUE tasks. """,
|
||||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
XXX_START_DOCSTRING,
|
||||||
|
XXX_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
|
class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
@@ -378,22 +391,23 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
|
|||||||
logits = outputs[0]
|
logits = outputs[0]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super(TFXxxForSequenceClassification, self).__init__(config, *inputs, **kwargs)
|
super(TFXxxForSequenceClassification, self).__init__(config, *inputs, **kwargs)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
|
||||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
self.classifier = tf.keras.layers.Dense(config.num_labels,
|
self.classifier = tf.keras.layers.Dense(
|
||||||
kernel_initializer=get_initializer(config.initializer_range),
|
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||||
name='classifier')
|
)
|
||||||
|
|
||||||
def call(self, inputs, **kwargs):
|
def call(self, inputs, **kwargs):
|
||||||
outputs = self.transformer(inputs, **kwargs)
|
outputs = self.transformer(inputs, **kwargs)
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
|
pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False))
|
||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
|
|
||||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||||
@@ -401,9 +415,12 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
|
|||||||
return outputs # logits, (hidden_states), (attentions)
|
return outputs # logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Xxx Model with a token classification head on top (a linear layer on top of
|
@add_start_docstrings(
|
||||||
|
"""Xxx Model with a token classification head on top (a linear layer on top of
|
||||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
XXX_START_DOCSTRING,
|
||||||
|
XXX_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class TFXxxForTokenClassification(TFXxxPreTrainedModel):
|
class TFXxxForTokenClassification(TFXxxPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
@@ -429,22 +446,23 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel):
|
|||||||
scores = outputs[0]
|
scores = outputs[0]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super(TFXxxForTokenClassification, self).__init__(config, *inputs, **kwargs)
|
super(TFXxxForTokenClassification, self).__init__(config, *inputs, **kwargs)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
|
||||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
self.classifier = tf.keras.layers.Dense(config.num_labels,
|
self.classifier = tf.keras.layers.Dense(
|
||||||
kernel_initializer=get_initializer(config.initializer_range),
|
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||||
name='classifier')
|
)
|
||||||
|
|
||||||
def call(self, inputs, **kwargs):
|
def call(self, inputs, **kwargs):
|
||||||
outputs = self.transformer(inputs, **kwargs)
|
outputs = self.transformer(inputs, **kwargs)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False))
|
sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
|
||||||
logits = self.classifier(sequence_output)
|
logits = self.classifier(sequence_output)
|
||||||
|
|
||||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||||
@@ -452,9 +470,12 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel):
|
|||||||
return outputs # scores, (hidden_states), (attentions)
|
return outputs # scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
@add_start_docstrings(
|
||||||
|
"""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
XXX_START_DOCSTRING,
|
||||||
|
XXX_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
|
class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
@@ -482,14 +503,15 @@ class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
|
|||||||
start_scores, end_scores = outputs[:2]
|
start_scores, end_scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super(TFXxxForQuestionAnswering, self).__init__(config, *inputs, **kwargs)
|
super(TFXxxForQuestionAnswering, self).__init__(config, *inputs, **kwargs)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
|
|
||||||
self.transformer = TFXxxMainLayer(config, name='transformer')
|
self.transformer = TFXxxMainLayer(config, name="transformer")
|
||||||
self.qa_outputs = tf.keras.layers.Dense(config.num_labels,
|
self.qa_outputs = tf.keras.layers.Dense(
|
||||||
kernel_initializer=get_initializer(config.initializer_range),
|
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||||
name='qa_outputs')
|
)
|
||||||
|
|
||||||
def call(self, inputs, **kwargs):
|
def call(self, inputs, **kwargs):
|
||||||
outputs = self.transformer(inputs, **kwargs)
|
outputs = self.transformer(inputs, **kwargs)
|
||||||
|
|||||||
@@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
|
|||||||
# for the pretrained weights provided with the models
|
# for the pretrained weights provided with the models
|
||||||
####################################################
|
####################################################
|
||||||
XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-pytorch_model.bin",
|
"xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-pytorch_model.bin",
|
||||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-pytorch_model.bin",
|
"xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-pytorch_model.bin",
|
||||||
}
|
}
|
||||||
|
|
||||||
####################################################
|
####################################################
|
||||||
@@ -60,8 +60,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
logger.error(
|
||||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||||
|
"https://www.tensorflow.org/install/ for installation instructions."
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||||
@@ -76,7 +78,7 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
|||||||
arrays.append(array)
|
arrays.append(array)
|
||||||
|
|
||||||
for name, array in zip(names, arrays):
|
for name, array in zip(names, arrays):
|
||||||
name = name.split('/')
|
name = name.split("/")
|
||||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||||
# which are not required for using pretrained model
|
# which are not required for using pretrained model
|
||||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
||||||
@@ -84,18 +86,18 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
|||||||
continue
|
continue
|
||||||
pointer = model
|
pointer = model
|
||||||
for m_name in name:
|
for m_name in name:
|
||||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
||||||
l = re.split(r'_(\d+)', m_name)
|
l = re.split(r"_(\d+)", m_name)
|
||||||
else:
|
else:
|
||||||
l = [m_name]
|
l = [m_name]
|
||||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
if l[0] == "kernel" or l[0] == "gamma":
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, "weight")
|
||||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
elif l[0] == "output_bias" or l[0] == "beta":
|
||||||
pointer = getattr(pointer, 'bias')
|
pointer = getattr(pointer, "bias")
|
||||||
elif l[0] == 'output_weights':
|
elif l[0] == "output_weights":
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, "weight")
|
||||||
elif l[0] == 'squad':
|
elif l[0] == "squad":
|
||||||
pointer = getattr(pointer, 'classifier')
|
pointer = getattr(pointer, "classifier")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
pointer = getattr(pointer, l[0])
|
pointer = getattr(pointer, l[0])
|
||||||
@@ -105,9 +107,9 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
|||||||
if len(l) >= 2:
|
if len(l) >= 2:
|
||||||
num = int(l[1])
|
num = int(l[1])
|
||||||
pointer = pointer[num]
|
pointer = pointer[num]
|
||||||
if m_name[-11:] == '_embeddings':
|
if m_name[-11:] == "_embeddings":
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, "weight")
|
||||||
elif m_name == 'kernel':
|
elif m_name == "kernel":
|
||||||
array = np.transpose(array)
|
array = np.transpose(array)
|
||||||
try:
|
try:
|
||||||
assert pointer.shape == array.shape
|
assert pointer.shape == array.shape
|
||||||
@@ -147,7 +149,6 @@ class XxxLayer(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
####################################################
|
####################################################
|
||||||
# PreTrainedModel is a sub-class of torch.nn.Module
|
# PreTrainedModel is a sub-class of torch.nn.Module
|
||||||
# which take care of loading and saving pretrained weights
|
# which take care of loading and saving pretrained weights
|
||||||
@@ -161,6 +162,7 @@ class XxxPreTrainedModel(PreTrainedModel):
|
|||||||
""" An abstract class to handle weights initialization and
|
""" An abstract class to handle weights initialization and
|
||||||
a simple interface for dowloading and loading pretrained models.
|
a simple interface for dowloading and loading pretrained models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = XxxConfig
|
config_class = XxxConfig
|
||||||
pretrained_model_archive_map = XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_tf_weights = load_tf_weights_in_xxx
|
load_tf_weights = load_tf_weights_in_xxx
|
||||||
@@ -246,8 +248,12 @@ XXX_INPUTS_DOCSTRING = r"""
|
|||||||
than the model's internal embedding lookup matrix.
|
than the model's internal embedding lookup matrix.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@add_start_docstrings("The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
|
|
||||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
@add_start_docstrings(
|
||||||
|
"The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
XXX_START_DOCSTRING,
|
||||||
|
XXX_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class XxxModel(XxxPreTrainedModel):
|
class XxxModel(XxxPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
@@ -277,6 +283,7 @@ class XxxModel(XxxPreTrainedModel):
|
|||||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(XxxModel, self).__init__(config)
|
super(XxxModel, self).__init__(config)
|
||||||
|
|
||||||
@@ -300,7 +307,15 @@ class XxxModel(XxxPreTrainedModel):
|
|||||||
for layer, heads in heads_to_prune.items():
|
for layer, heads in heads_to_prune.items():
|
||||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
):
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif input_ids is not None:
|
elif input_ids is not None:
|
||||||
@@ -342,14 +357,20 @@ class XxxModel(XxxPreTrainedModel):
|
|||||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||||
elif head_mask.dim() == 2:
|
elif head_mask.dim() == 2:
|
||||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
head_mask = (
|
||||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
) # We can specify head_mask for each layer
|
||||||
|
head_mask = head_mask.to(
|
||||||
|
dtype=next(self.parameters()).dtype
|
||||||
|
) # switch to fload if need + fp16 compatibility
|
||||||
else:
|
else:
|
||||||
head_mask = [None] * self.config.num_hidden_layers
|
head_mask = [None] * self.config.num_hidden_layers
|
||||||
|
|
||||||
##################################
|
##################################
|
||||||
# Replace this with your model code
|
# Replace this with your model code
|
||||||
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
|
embedding_output = self.embeddings(
|
||||||
|
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||||
|
)
|
||||||
encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)
|
encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)
|
||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
outputs = (sequence_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
outputs = (sequence_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||||
@@ -357,8 +378,9 @@ class XxxModel(XxxPreTrainedModel):
|
|||||||
return outputs # sequence_output, (hidden_states), (attentions)
|
return outputs # sequence_output, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Xxx Model with a `language modeling` head on top. """,
|
@add_start_docstrings(
|
||||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
"""Xxx Model with a `language modeling` head on top. """, XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING
|
||||||
|
)
|
||||||
class XxxForMaskedLM(XxxPreTrainedModel):
|
class XxxForMaskedLM(XxxPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
@@ -389,6 +411,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
|
|||||||
loss, prediction_scores = outputs[:2]
|
loss, prediction_scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(XxxForMaskedLM, self).__init__(config)
|
super(XxxForMaskedLM, self).__init__(config)
|
||||||
|
|
||||||
@@ -400,15 +423,25 @@ class XxxForMaskedLM(XxxPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
def forward(
|
||||||
masked_lm_labels=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
masked_lm_labels=None,
|
||||||
|
):
|
||||||
|
|
||||||
outputs = self.transformer(input_ids,
|
outputs = self.transformer(
|
||||||
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
prediction_scores = self.cls(sequence_output)
|
prediction_scores = self.cls(sequence_output)
|
||||||
@@ -422,9 +455,12 @@ class XxxForMaskedLM(XxxPreTrainedModel):
|
|||||||
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
@add_start_docstrings(
|
||||||
|
"""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||||
the pooled output) e.g. for GLUE tasks. """,
|
the pooled output) e.g. for GLUE tasks. """,
|
||||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
XXX_START_DOCSTRING,
|
||||||
|
XXX_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class XxxForSequenceClassification(XxxPreTrainedModel):
|
class XxxForSequenceClassification(XxxPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
@@ -456,6 +492,7 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
|
|||||||
loss, logits = outputs[:2]
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(XxxForSequenceClassification, self).__init__(config)
|
super(XxxForSequenceClassification, self).__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
@@ -466,15 +503,25 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
def forward(
|
||||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
):
|
||||||
|
|
||||||
outputs = self.transformer(input_ids,
|
outputs = self.transformer(
|
||||||
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
@@ -496,9 +543,12 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
|
|||||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Xxx Model with a token classification head on top (a linear layer on top of
|
@add_start_docstrings(
|
||||||
|
"""Xxx Model with a token classification head on top (a linear layer on top of
|
||||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
XXX_START_DOCSTRING,
|
||||||
|
XXX_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class XxxForTokenClassification(XxxPreTrainedModel):
|
class XxxForTokenClassification(XxxPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
@@ -528,6 +578,7 @@ class XxxForTokenClassification(XxxPreTrainedModel):
|
|||||||
loss, scores = outputs[:2]
|
loss, scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(XxxForTokenClassification, self).__init__(config)
|
super(XxxForTokenClassification, self).__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
@@ -538,15 +589,25 @@ class XxxForTokenClassification(XxxPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
def forward(
|
||||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
):
|
||||||
|
|
||||||
outputs = self.transformer(input_ids,
|
outputs = self.transformer(
|
||||||
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
@@ -569,9 +630,12 @@ class XxxForTokenClassification(XxxPreTrainedModel):
|
|||||||
return outputs # (loss), scores, (hidden_states), (attentions)
|
return outputs # (loss), scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
@add_start_docstrings(
|
||||||
|
"""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||||
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
|
XXX_START_DOCSTRING,
|
||||||
|
XXX_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class XxxForQuestionAnswering(XxxPreTrainedModel):
|
class XxxForQuestionAnswering(XxxPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
@@ -613,6 +677,7 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(XxxForQuestionAnswering, self).__init__(config)
|
super(XxxForQuestionAnswering, self).__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
@@ -622,15 +687,26 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
def forward(
|
||||||
start_positions=None, end_positions=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
start_positions=None,
|
||||||
|
end_positions=None,
|
||||||
|
):
|
||||||
|
|
||||||
outputs = self.transformer(input_ids,
|
outputs = self.transformer(
|
||||||
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from __future__ import print_function
|
|||||||
import unittest
|
import unittest
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
from .modeling_tf_common_test import TFCommonTestCases, ids_tensor
|
||||||
from .configuration_common_test import ConfigTester
|
from .configuration_common_test import ConfigTester
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
@@ -27,23 +27,34 @@ from transformers import XxxConfig, is_tf_available
|
|||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from transformers.modeling_tf_xxx import (TFXxxModel, TFXxxForMaskedLM,
|
from transformers.modeling_tf_xxx import (
|
||||||
|
TFXxxModel,
|
||||||
|
TFXxxForMaskedLM,
|
||||||
TFXxxForSequenceClassification,
|
TFXxxForSequenceClassification,
|
||||||
TFXxxForTokenClassification,
|
TFXxxForTokenClassification,
|
||||||
TFXxxForQuestionAnswering,
|
TFXxxForQuestionAnswering,
|
||||||
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||||
|
|
||||||
all_model_classes = (TFXxxModel, TFXxxForMaskedLM, TFXxxForQuestionAnswering,
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
TFXxxModel,
|
||||||
|
TFXxxForMaskedLM,
|
||||||
|
TFXxxForQuestionAnswering,
|
||||||
TFXxxForSequenceClassification,
|
TFXxxForSequenceClassification,
|
||||||
TFXxxForTokenClassification) if is_tf_available() else ()
|
TFXxxForTokenClassification,
|
||||||
|
)
|
||||||
|
if is_tf_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
|
||||||
class TFXxxModelTester(object):
|
class TFXxxModelTester(object):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=13,
|
batch_size=13,
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
@@ -120,15 +131,16 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
|||||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
max_position_embeddings=self.max_position_embeddings,
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
type_vocab_size=self.type_vocab_size,
|
type_vocab_size=self.type_vocab_size,
|
||||||
initializer_range=self.initializer_range)
|
initializer_range=self.initializer_range,
|
||||||
|
)
|
||||||
|
|
||||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
|
||||||
def create_and_check_xxx_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
def create_and_check_xxx_model(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
model = TFXxxModel(config=config)
|
model = TFXxxModel(config=config)
|
||||||
inputs = {'input_ids': input_ids,
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
'attention_mask': input_mask,
|
|
||||||
'token_type_ids': token_type_ids}
|
|
||||||
sequence_output, pooled_output = model(inputs)
|
sequence_output, pooled_output = model(inputs)
|
||||||
|
|
||||||
inputs = [input_ids, input_mask]
|
inputs = [input_ids, input_mask]
|
||||||
@@ -141,78 +153,74 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
|||||||
"pooled_output": pooled_output.numpy(),
|
"pooled_output": pooled_output.numpy(),
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["sequence_output"].shape),
|
list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||||
[self.batch_size, self.seq_length, self.hidden_size])
|
)
|
||||||
self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])
|
self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
|
def create_and_check_xxx_for_masked_lm(
|
||||||
def create_and_check_xxx_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
model = TFXxxForMaskedLM(config=config)
|
model = TFXxxForMaskedLM(config=config)
|
||||||
inputs = {'input_ids': input_ids,
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
'attention_mask': input_mask,
|
(prediction_scores,) = model(inputs)
|
||||||
'token_type_ids': token_type_ids}
|
|
||||||
prediction_scores, = model(inputs)
|
|
||||||
result = {
|
result = {
|
||||||
"prediction_scores": prediction_scores.numpy(),
|
"prediction_scores": prediction_scores.numpy(),
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["prediction_scores"].shape),
|
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||||
[self.batch_size, self.seq_length, self.vocab_size])
|
)
|
||||||
|
|
||||||
|
def create_and_check_xxx_for_sequence_classification(
|
||||||
def create_and_check_xxx_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
config.num_labels = self.num_labels
|
config.num_labels = self.num_labels
|
||||||
model = TFXxxForSequenceClassification(config=config)
|
model = TFXxxForSequenceClassification(config=config)
|
||||||
inputs = {'input_ids': input_ids,
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
'attention_mask': input_mask,
|
(logits,) = model(inputs)
|
||||||
'token_type_ids': token_type_ids}
|
|
||||||
logits, = model(inputs)
|
|
||||||
result = {
|
result = {
|
||||||
"logits": logits.numpy(),
|
"logits": logits.numpy(),
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||||
list(result["logits"].shape),
|
|
||||||
[self.batch_size, self.num_labels])
|
|
||||||
|
|
||||||
|
def create_and_check_xxx_for_token_classification(
|
||||||
def create_and_check_xxx_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
config.num_labels = self.num_labels
|
config.num_labels = self.num_labels
|
||||||
model = TFXxxForTokenClassification(config=config)
|
model = TFXxxForTokenClassification(config=config)
|
||||||
inputs = {'input_ids': input_ids,
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
'attention_mask': input_mask,
|
(logits,) = model(inputs)
|
||||||
'token_type_ids': token_type_ids}
|
|
||||||
logits, = model(inputs)
|
|
||||||
result = {
|
result = {
|
||||||
"logits": logits.numpy(),
|
"logits": logits.numpy(),
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["logits"].shape),
|
list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
|
||||||
[self.batch_size, self.seq_length, self.num_labels])
|
)
|
||||||
|
|
||||||
|
def create_and_check_xxx_for_question_answering(
|
||||||
def create_and_check_xxx_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
model = TFXxxForQuestionAnswering(config=config)
|
model = TFXxxForQuestionAnswering(config=config)
|
||||||
inputs = {'input_ids': input_ids,
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
'attention_mask': input_mask,
|
|
||||||
'token_type_ids': token_type_ids}
|
|
||||||
start_logits, end_logits = model(inputs)
|
start_logits, end_logits = model(inputs)
|
||||||
result = {
|
result = {
|
||||||
"start_logits": start_logits.numpy(),
|
"start_logits": start_logits.numpy(),
|
||||||
"end_logits": end_logits.numpy(),
|
"end_logits": end_logits.numpy(),
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||||
list(result["start_logits"].shape),
|
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||||
[self.batch_size, self.seq_length])
|
|
||||||
self.parent.assertListEqual(
|
|
||||||
list(result["end_logits"].shape),
|
|
||||||
[self.batch_size, self.seq_length])
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(config, input_ids, token_type_ids, input_mask,
|
(
|
||||||
sequence_labels, token_labels, choice_labels) = config_and_inputs
|
config,
|
||||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -244,9 +252,10 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in ['xxx-base-uncased']:
|
for model_name in ["xxx-base-uncased"]:
|
||||||
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -20,28 +20,37 @@ import unittest
|
|||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
from .modeling_common_test import CommonTestCases, ids_tensor
|
||||||
from .configuration_common_test import ConfigTester
|
from .configuration_common_test import ConfigTester
|
||||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers import (XxxConfig, XxxModel, XxxForMaskedLM,
|
from transformers import (
|
||||||
XxxForNextSentencePrediction, XxxForPreTraining,
|
XxxConfig,
|
||||||
XxxForQuestionAnswering, XxxForSequenceClassification,
|
XxxModel,
|
||||||
XxxForTokenClassification, XxxForMultipleChoice)
|
XxxForMaskedLM,
|
||||||
|
XxxForNextSentencePrediction,
|
||||||
|
XxxForPreTraining,
|
||||||
|
XxxForQuestionAnswering,
|
||||||
|
XxxForSequenceClassification,
|
||||||
|
XxxForTokenClassification,
|
||||||
|
XxxForMultipleChoice,
|
||||||
|
)
|
||||||
from transformers.modeling_xxx import XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
from transformers.modeling_xxx import XXX_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class XxxModelTest(CommonTestCases.CommonModelTester):
|
class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||||
|
|
||||||
all_model_classes = (XxxModel, XxxForMaskedLM, XxxForQuestionAnswering,
|
all_model_classes = (
|
||||||
XxxForSequenceClassification,
|
(XxxModel, XxxForMaskedLM, XxxForQuestionAnswering, XxxForSequenceClassification, XxxForTokenClassification)
|
||||||
XxxForTokenClassification) if is_torch_available() else ()
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
|
||||||
class XxxModelTester(object):
|
class XxxModelTester(object):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=13,
|
batch_size=13,
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
@@ -118,16 +127,17 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
|
|||||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
max_position_embeddings=self.max_position_embeddings,
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
type_vocab_size=self.type_vocab_size,
|
type_vocab_size=self.type_vocab_size,
|
||||||
initializer_range=self.initializer_range)
|
initializer_range=self.initializer_range,
|
||||||
|
)
|
||||||
|
|
||||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
|
||||||
def check_loss_output(self, result):
|
def check_loss_output(self, result):
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||||
list(result["loss"].size()),
|
|
||||||
[])
|
|
||||||
|
|
||||||
def create_and_check_xxx_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
def create_and_check_xxx_model(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
model = XxxModel(config=config)
|
model = XxxModel(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -140,83 +150,98 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
|
|||||||
"pooled_output": pooled_output,
|
"pooled_output": pooled_output,
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["sequence_output"].size()),
|
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||||
[self.batch_size, self.seq_length, self.hidden_size])
|
)
|
||||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
|
def create_and_check_xxx_for_masked_lm(
|
||||||
def create_and_check_xxx_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
model = XxxForMaskedLM(config=config)
|
model = XxxForMaskedLM(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)
|
loss, prediction_scores = model(
|
||||||
|
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
|
||||||
|
)
|
||||||
result = {
|
result = {
|
||||||
"loss": loss,
|
"loss": loss,
|
||||||
"prediction_scores": prediction_scores,
|
"prediction_scores": prediction_scores,
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["prediction_scores"].size()),
|
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||||
[self.batch_size, self.seq_length, self.vocab_size])
|
)
|
||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
def create_and_check_xxx_for_question_answering(
|
||||||
def create_and_check_xxx_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
model = XxxForQuestionAnswering(config=config)
|
model = XxxForQuestionAnswering(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
|
loss, start_logits, end_logits = model(
|
||||||
start_positions=sequence_labels, end_positions=sequence_labels)
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
start_positions=sequence_labels,
|
||||||
|
end_positions=sequence_labels,
|
||||||
|
)
|
||||||
result = {
|
result = {
|
||||||
"loss": loss,
|
"loss": loss,
|
||||||
"start_logits": start_logits,
|
"start_logits": start_logits,
|
||||||
"end_logits": end_logits,
|
"end_logits": end_logits,
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||||
list(result["start_logits"].size()),
|
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||||
[self.batch_size, self.seq_length])
|
|
||||||
self.parent.assertListEqual(
|
|
||||||
list(result["end_logits"].size()),
|
|
||||||
[self.batch_size, self.seq_length])
|
|
||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
def create_and_check_xxx_for_sequence_classification(
|
||||||
def create_and_check_xxx_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
config.num_labels = self.num_labels
|
config.num_labels = self.num_labels
|
||||||
model = XxxForSequenceClassification(config)
|
model = XxxForSequenceClassification(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
loss, logits = model(
|
||||||
|
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||||
|
)
|
||||||
result = {
|
result = {
|
||||||
"loss": loss,
|
"loss": loss,
|
||||||
"logits": logits,
|
"logits": logits,
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||||
list(result["logits"].size()),
|
|
||||||
[self.batch_size, self.num_labels])
|
|
||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
def create_and_check_xxx_for_token_classification(
|
||||||
def create_and_check_xxx_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
config.num_labels = self.num_labels
|
config.num_labels = self.num_labels
|
||||||
model = XxxForTokenClassification(config=config)
|
model = XxxForTokenClassification(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
loss, logits = model(
|
||||||
|
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||||
|
)
|
||||||
result = {
|
result = {
|
||||||
"loss": loss,
|
"loss": loss,
|
||||||
"logits": logits,
|
"logits": logits,
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["logits"].size()),
|
list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
|
||||||
[self.batch_size, self.seq_length, self.num_labels])
|
)
|
||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(config, input_ids, token_type_ids, input_mask,
|
(
|
||||||
sequence_labels, token_labels, choice_labels) = config_and_inputs
|
config,
|
||||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -252,5 +277,6 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
|
|||||||
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -18,10 +18,11 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from transformers.tokenization_bert import (XxxTokenizer, VOCAB_FILES_NAMES)
|
from transformers.tokenization_bert import XxxTokenizer, VOCAB_FILES_NAMES
|
||||||
|
|
||||||
from .tokenization_tests_commons import CommonTestCases
|
from .tokenization_tests_commons import CommonTestCases
|
||||||
|
|
||||||
|
|
||||||
class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||||
|
|
||||||
tokenizer_class = XxxTokenizer
|
tokenizer_class = XxxTokenizer
|
||||||
@@ -30,28 +31,39 @@ class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
super(XxxTokenizationTest, self).setUp()
|
super(XxxTokenizationTest, self).setUp()
|
||||||
|
|
||||||
vocab_tokens = [
|
vocab_tokens = [
|
||||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
"[UNK]",
|
||||||
"##ing", ",", "low", "lowest",
|
"[CLS]",
|
||||||
|
"[SEP]",
|
||||||
|
"want",
|
||||||
|
"##want",
|
||||||
|
"##ed",
|
||||||
|
"wa",
|
||||||
|
"un",
|
||||||
|
"runn",
|
||||||
|
"##ing",
|
||||||
|
",",
|
||||||
|
"low",
|
||||||
|
"lowest",
|
||||||
]
|
]
|
||||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||||
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
|
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
|
|
||||||
def get_tokenizer(self, **kwargs):
|
def get_tokenizer(self, **kwargs):
|
||||||
return XxxTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
return XxxTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
def get_input_output_texts(self):
|
def get_input_output_texts(self):
|
||||||
input_text = u"UNwant\u00E9d,running"
|
input_text = "UNwant\u00E9d,running"
|
||||||
output_text = u"unwanted, running"
|
output_text = "unwanted, running"
|
||||||
return input_text, output_text
|
return input_text, output_text
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
def test_full_tokenizer(self):
|
||||||
tokenizer = self.tokenizer_class(self.vocab_file)
|
tokenizer = self.tokenizer_class(self.vocab_file)
|
||||||
|
|
||||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
tokens = tokenizer.tokenize("UNwant\u00E9d,running")
|
||||||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -34,17 +34,16 @@ logger = logging.getLogger(__name__)
|
|||||||
# Mapping from the keyword arguments names of Tokenizer `__init__`
|
# Mapping from the keyword arguments names of Tokenizer `__init__`
|
||||||
# to file names for serializing Tokenizer instances
|
# to file names for serializing Tokenizer instances
|
||||||
####################################################
|
####################################################
|
||||||
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
|
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||||
|
|
||||||
####################################################
|
####################################################
|
||||||
# Mapping from the keyword arguments names of Tokenizer `__init__`
|
# Mapping from the keyword arguments names of Tokenizer `__init__`
|
||||||
# to pretrained vocabulary URL for all the model shortcut names.
|
# to pretrained vocabulary URL for all the model shortcut names.
|
||||||
####################################################
|
####################################################
|
||||||
PRETRAINED_VOCAB_FILES_MAP = {
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
'vocab_file':
|
"vocab_file": {
|
||||||
{
|
"xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-vocab.txt",
|
||||||
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-vocab.txt",
|
"xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-vocab.txt",
|
||||||
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-vocab.txt",
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,8 +51,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
# Mapping from model shortcut names to max length of inputs
|
# Mapping from model shortcut names to max length of inputs
|
||||||
####################################################
|
####################################################
|
||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
'xxx-base-uncased': 512,
|
"xxx-base-uncased": 512,
|
||||||
'xxx-large-uncased': 512,
|
"xxx-large-uncased": 512,
|
||||||
}
|
}
|
||||||
|
|
||||||
####################################################
|
####################################################
|
||||||
@@ -62,8 +61,8 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
|||||||
# To be used for checkpoint specific configurations.
|
# To be used for checkpoint specific configurations.
|
||||||
####################################################
|
####################################################
|
||||||
PRETRAINED_INIT_CONFIGURATION = {
|
PRETRAINED_INIT_CONFIGURATION = {
|
||||||
'xxx-base-uncased': {'do_lower_case': True},
|
"xxx-base-uncased": {"do_lower_case": True},
|
||||||
'xxx-large-uncased': {'do_lower_case': True},
|
"xxx-large-uncased": {"do_lower_case": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -73,7 +72,7 @@ def load_vocab(vocab_file):
|
|||||||
with open(vocab_file, "r", encoding="utf-8") as reader:
|
with open(vocab_file, "r", encoding="utf-8") as reader:
|
||||||
tokens = reader.readlines()
|
tokens = reader.readlines()
|
||||||
for index, token in enumerate(tokens):
|
for index, token in enumerate(tokens):
|
||||||
token = token.rstrip('\n')
|
token = token.rstrip("\n")
|
||||||
vocab[token] = index
|
vocab[token] = index
|
||||||
return vocab
|
return vocab
|
||||||
|
|
||||||
@@ -93,9 +92,17 @@ class XxxTokenizer(PreTrainedTokenizer):
|
|||||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
|
||||||
def __init__(self, vocab_file, do_lower_case=True,
|
def __init__(
|
||||||
unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
|
self,
|
||||||
mask_token="[MASK]", **kwargs):
|
vocab_file,
|
||||||
|
do_lower_case=True,
|
||||||
|
unk_token="[UNK]",
|
||||||
|
sep_token="[SEP]",
|
||||||
|
pad_token="[PAD]",
|
||||||
|
cls_token="[CLS]",
|
||||||
|
mask_token="[MASK]",
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
"""Constructs a XxxTokenizer.
|
"""Constructs a XxxTokenizer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -104,16 +111,22 @@ class XxxTokenizer(PreTrainedTokenizer):
|
|||||||
Whether to lower case the input
|
Whether to lower case the input
|
||||||
Only has an effect when do_basic_tokenize=True
|
Only has an effect when do_basic_tokenize=True
|
||||||
"""
|
"""
|
||||||
super(XxxTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
|
super(XxxTokenizer, self).__init__(
|
||||||
pad_token=pad_token, cls_token=cls_token,
|
unk_token=unk_token,
|
||||||
mask_token=mask_token, **kwargs)
|
sep_token=sep_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
cls_token=cls_token,
|
||||||
|
mask_token=mask_token,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
|
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
|
||||||
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
|
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
|
||||||
|
|
||||||
if not os.path.isfile(vocab_file):
|
if not os.path.isfile(vocab_file):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
||||||
"model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
|
"model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
|
||||||
|
)
|
||||||
self.vocab = load_vocab(vocab_file)
|
self.vocab = load_vocab(vocab_file)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -142,7 +155,7 @@ class XxxTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def convert_tokens_to_string(self, tokens):
|
def convert_tokens_to_string(self, tokens):
|
||||||
""" Converts a sequence of tokens (string) in a single string. """
|
""" Converts a sequence of tokens (string) in a single string. """
|
||||||
out_string = ' '.join(tokens).replace(' ##', '').strip()
|
out_string = " ".join(tokens).replace(" ##", "").strip()
|
||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||||
@@ -177,8 +190,10 @@ class XxxTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
if already_has_special_tokens:
|
if already_has_special_tokens:
|
||||||
if token_ids_1 is not None:
|
if token_ids_1 is not None:
|
||||||
raise ValueError("You should not supply a second sequence if the provided sequence of "
|
raise ValueError(
|
||||||
"ids is already formated with special tokens for the model.")
|
"You should not supply a second sequence if the provided sequence of "
|
||||||
|
"ids is already formated with special tokens for the model."
|
||||||
|
)
|
||||||
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
||||||
|
|
||||||
if token_ids_1 is not None:
|
if token_ids_1 is not None:
|
||||||
@@ -204,15 +219,17 @@ class XxxTokenizer(PreTrainedTokenizer):
|
|||||||
"""Save the tokenizer vocabulary to a directory or file."""
|
"""Save the tokenizer vocabulary to a directory or file."""
|
||||||
index = 0
|
index = 0
|
||||||
if os.path.isdir(vocab_path):
|
if os.path.isdir(vocab_path):
|
||||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
|
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
|
||||||
else:
|
else:
|
||||||
vocab_file = vocab_path
|
vocab_file = vocab_path
|
||||||
with open(vocab_file, "w", encoding="utf-8") as writer:
|
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||||
if index != token_index:
|
if index != token_index:
|
||||||
logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
|
logger.warning(
|
||||||
" Please check that the vocabulary is not corrupted!".format(vocab_file))
|
"Saving vocabulary to {}: vocabulary indices are not consecutive."
|
||||||
|
" Please check that the vocabulary is not corrupted!".format(vocab_file)
|
||||||
|
)
|
||||||
index = token_index
|
index = token_index
|
||||||
writer.write(token + u'\n')
|
writer.write(token + "\n")
|
||||||
index += 1
|
index += 1
|
||||||
return (vocab_file,)
|
return (vocab_file,)
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ __version__ = "2.3.0"
|
|||||||
# and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493
|
# and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493
|
||||||
try:
|
try:
|
||||||
import absl.logging
|
import absl.logging
|
||||||
absl.logging.set_verbosity('info')
|
|
||||||
absl.logging.set_stderrthreshold('info')
|
absl.logging.set_verbosity("info")
|
||||||
|
absl.logging.set_stderrthreshold("info")
|
||||||
absl.logging._warn_preinit_stderr = False
|
absl.logging._warn_preinit_stderr = False
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
@@ -17,19 +18,41 @@ import logging
|
|||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
# Files and general utilities
|
# Files and general utilities
|
||||||
from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
from .file_utils import (
|
||||||
cached_path, add_start_docstrings, add_end_docstrings,
|
TRANSFORMERS_CACHE,
|
||||||
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME, MODEL_CARD_NAME,
|
PYTORCH_TRANSFORMERS_CACHE,
|
||||||
is_tf_available, is_torch_available)
|
PYTORCH_PRETRAINED_BERT_CACHE,
|
||||||
|
cached_path,
|
||||||
|
add_start_docstrings,
|
||||||
|
add_end_docstrings,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
TF2_WEIGHTS_NAME,
|
||||||
|
TF_WEIGHTS_NAME,
|
||||||
|
CONFIG_NAME,
|
||||||
|
MODEL_CARD_NAME,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
)
|
||||||
|
|
||||||
from .data import (is_sklearn_available,
|
from .data import (
|
||||||
InputExample, InputFeatures, DataProcessor,
|
is_sklearn_available,
|
||||||
|
InputExample,
|
||||||
|
InputFeatures,
|
||||||
|
DataProcessor,
|
||||||
SingleSentenceClassificationProcessor,
|
SingleSentenceClassificationProcessor,
|
||||||
glue_output_modes, glue_convert_examples_to_features,
|
glue_output_modes,
|
||||||
glue_processors, glue_tasks_num_labels,
|
glue_convert_examples_to_features,
|
||||||
xnli_output_modes, xnli_processors, xnli_tasks_num_labels,
|
glue_processors,
|
||||||
squad_convert_examples_to_features, SquadFeatures,
|
glue_tasks_num_labels,
|
||||||
SquadExample, SquadV1Processor, SquadV2Processor)
|
xnli_output_modes,
|
||||||
|
xnli_processors,
|
||||||
|
xnli_tasks_num_labels,
|
||||||
|
squad_convert_examples_to_features,
|
||||||
|
SquadFeatures,
|
||||||
|
SquadExample,
|
||||||
|
SquadV1Processor,
|
||||||
|
SquadV2Processor,
|
||||||
|
)
|
||||||
|
|
||||||
if is_sklearn_available():
|
if is_sklearn_available():
|
||||||
from .data import glue_compute_metrics, xnli_compute_metrics
|
from .data import glue_compute_metrics, xnli_compute_metrics
|
||||||
@@ -38,12 +61,12 @@ if is_sklearn_available():
|
|||||||
from .modelcard import ModelCard
|
from .modelcard import ModelCard
|
||||||
|
|
||||||
# Tokenizers
|
# Tokenizers
|
||||||
from .tokenization_utils import (PreTrainedTokenizer)
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
from .tokenization_auto import AutoTokenizer
|
from .tokenization_auto import AutoTokenizer
|
||||||
from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
||||||
from .tokenization_bert_japanese import BertJapaneseTokenizer, MecabTokenizer, CharacterTokenizer
|
from .tokenization_bert_japanese import BertJapaneseTokenizer, MecabTokenizer, CharacterTokenizer
|
||||||
from .tokenization_openai import OpenAIGPTTokenizer
|
from .tokenization_openai import OpenAIGPTTokenizer
|
||||||
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
|
from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLCorpus
|
||||||
from .tokenization_gpt2 import GPT2Tokenizer
|
from .tokenization_gpt2 import GPT2Tokenizer
|
||||||
from .tokenization_ctrl import CTRLTokenizer
|
from .tokenization_ctrl import CTRLTokenizer
|
||||||
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
|
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
|
||||||
@@ -75,143 +98,281 @@ from .configuration_mmbt import MMBTConfig
|
|||||||
|
|
||||||
# Modeling
|
# Modeling
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
|
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D
|
||||||
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
|
from .modeling_auto import (
|
||||||
AutoModelWithLMHead, AutoModelForTokenClassification, ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
AutoModel,
|
||||||
|
AutoModelForSequenceClassification,
|
||||||
|
AutoModelForQuestionAnswering,
|
||||||
|
AutoModelWithLMHead,
|
||||||
|
AutoModelForTokenClassification,
|
||||||
|
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_bert import (BertPreTrainedModel, BertModel, BertForPreTraining,
|
from .modeling_bert import (
|
||||||
BertForMaskedLM, BertForNextSentencePrediction,
|
BertPreTrainedModel,
|
||||||
BertForSequenceClassification, BertForMultipleChoice,
|
BertModel,
|
||||||
BertForTokenClassification, BertForQuestionAnswering,
|
BertForPreTraining,
|
||||||
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
BertForMaskedLM,
|
||||||
from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
|
BertForNextSentencePrediction,
|
||||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
|
BertForSequenceClassification,
|
||||||
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
BertForMultipleChoice,
|
||||||
from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel,
|
BertForTokenClassification,
|
||||||
|
BertForQuestionAnswering,
|
||||||
|
load_tf_weights_in_bert,
|
||||||
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
from .modeling_openai import (
|
||||||
|
OpenAIGPTPreTrainedModel,
|
||||||
|
OpenAIGPTModel,
|
||||||
|
OpenAIGPTLMHeadModel,
|
||||||
|
OpenAIGPTDoubleHeadsModel,
|
||||||
|
load_tf_weights_in_openai_gpt,
|
||||||
|
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
from .modeling_transfo_xl import (
|
||||||
|
TransfoXLPreTrainedModel,
|
||||||
|
TransfoXLModel,
|
||||||
|
TransfoXLLMHeadModel,
|
||||||
AdaptiveEmbedding,
|
AdaptiveEmbedding,
|
||||||
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
load_tf_weights_in_transfo_xl,
|
||||||
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
|
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
GPT2LMHeadModel, GPT2DoubleHeadsModel,
|
)
|
||||||
load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
from .modeling_gpt2 import (
|
||||||
from .modeling_ctrl import (CTRLPreTrainedModel, CTRLModel,
|
GPT2PreTrainedModel,
|
||||||
CTRLLMHeadModel,
|
GPT2Model,
|
||||||
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
GPT2LMHeadModel,
|
||||||
from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
|
GPT2DoubleHeadsModel,
|
||||||
XLNetForSequenceClassification, XLNetForTokenClassification,
|
load_tf_weights_in_gpt2,
|
||||||
XLNetForMultipleChoice, XLNetForQuestionAnsweringSimple,
|
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
XLNetForQuestionAnswering, load_tf_weights_in_xlnet,
|
)
|
||||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
from .modeling_ctrl import CTRLPreTrainedModel, CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
|
from .modeling_xlnet import (
|
||||||
XLMWithLMHeadModel, XLMForSequenceClassification,
|
XLNetPreTrainedModel,
|
||||||
XLMForQuestionAnswering, XLMForQuestionAnsweringSimple,
|
XLNetModel,
|
||||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
XLNetLMHeadModel,
|
||||||
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel,
|
XLNetForSequenceClassification,
|
||||||
RobertaForSequenceClassification, RobertaForMultipleChoice,
|
XLNetForTokenClassification,
|
||||||
RobertaForTokenClassification, RobertaForQuestionAnswering,
|
XLNetForMultipleChoice,
|
||||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
XLNetForQuestionAnsweringSimple,
|
||||||
from .modeling_distilbert import (DistilBertPreTrainedModel, DistilBertForMaskedLM, DistilBertModel,
|
XLNetForQuestionAnswering,
|
||||||
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
|
load_tf_weights_in_xlnet,
|
||||||
|
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
from .modeling_xlm import (
|
||||||
|
XLMPreTrainedModel,
|
||||||
|
XLMModel,
|
||||||
|
XLMWithLMHeadModel,
|
||||||
|
XLMForSequenceClassification,
|
||||||
|
XLMForQuestionAnswering,
|
||||||
|
XLMForQuestionAnsweringSimple,
|
||||||
|
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
from .modeling_roberta import (
|
||||||
|
RobertaForMaskedLM,
|
||||||
|
RobertaModel,
|
||||||
|
RobertaForSequenceClassification,
|
||||||
|
RobertaForMultipleChoice,
|
||||||
|
RobertaForTokenClassification,
|
||||||
|
RobertaForQuestionAnswering,
|
||||||
|
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
from .modeling_distilbert import (
|
||||||
|
DistilBertPreTrainedModel,
|
||||||
|
DistilBertForMaskedLM,
|
||||||
|
DistilBertModel,
|
||||||
|
DistilBertForSequenceClassification,
|
||||||
|
DistilBertForQuestionAnswering,
|
||||||
DistilBertForTokenClassification,
|
DistilBertForTokenClassification,
|
||||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
from .modeling_camembert import (CamembertForMaskedLM, CamembertModel,
|
)
|
||||||
CamembertForSequenceClassification, CamembertForMultipleChoice,
|
from .modeling_camembert import (
|
||||||
|
CamembertForMaskedLM,
|
||||||
|
CamembertModel,
|
||||||
|
CamembertForSequenceClassification,
|
||||||
|
CamembertForMultipleChoice,
|
||||||
CamembertForTokenClassification,
|
CamembertForTokenClassification,
|
||||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
|
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
|
||||||
from .modeling_t5 import (T5PreTrainedModel, T5Model, T5WithLMHeadModel,
|
from .modeling_t5 import (
|
||||||
|
T5PreTrainedModel,
|
||||||
|
T5Model,
|
||||||
|
T5WithLMHeadModel,
|
||||||
load_tf_weights_in_t5,
|
load_tf_weights_in_t5,
|
||||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP)
|
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
from .modeling_albert import (AlbertPreTrainedModel, AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification,
|
)
|
||||||
|
from .modeling_albert import (
|
||||||
|
AlbertPreTrainedModel,
|
||||||
|
AlbertModel,
|
||||||
|
AlbertForMaskedLM,
|
||||||
|
AlbertForSequenceClassification,
|
||||||
AlbertForQuestionAnswering,
|
AlbertForQuestionAnswering,
|
||||||
load_tf_weights_in_albert, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
load_tf_weights_in_albert,
|
||||||
from .modeling_xlm_roberta import (XLMRobertaForMaskedLM, XLMRobertaModel, XLMRobertaForMultipleChoice,
|
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
XLMRobertaForSequenceClassification, XLMRobertaForTokenClassification)
|
)
|
||||||
|
from .modeling_xlm_roberta import (
|
||||||
|
XLMRobertaForMaskedLM,
|
||||||
|
XLMRobertaModel,
|
||||||
|
XLMRobertaForMultipleChoice,
|
||||||
|
XLMRobertaForSequenceClassification,
|
||||||
|
XLMRobertaForTokenClassification,
|
||||||
|
)
|
||||||
from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification
|
from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification
|
||||||
|
|
||||||
# Optimization
|
# Optimization
|
||||||
from .optimization import (AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup,
|
from .optimization import (
|
||||||
get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup)
|
AdamW,
|
||||||
|
get_constant_schedule,
|
||||||
|
get_constant_schedule_with_warmup,
|
||||||
|
get_cosine_schedule_with_warmup,
|
||||||
|
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||||
|
get_linear_schedule_with_warmup,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TensorFlow
|
# TensorFlow
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
|
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
|
||||||
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
|
from .modeling_tf_auto import (
|
||||||
TFAutoModelWithLMHead, TFAutoModelForTokenClassification, TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TFAutoModel,
|
||||||
|
TFAutoModelForSequenceClassification,
|
||||||
|
TFAutoModelForQuestionAnswering,
|
||||||
|
TFAutoModelWithLMHead,
|
||||||
|
TFAutoModelForTokenClassification,
|
||||||
|
TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertMainLayer, TFBertEmbeddings,
|
from .modeling_tf_bert import (
|
||||||
TFBertModel, TFBertForPreTraining,
|
TFBertPreTrainedModel,
|
||||||
TFBertForMaskedLM, TFBertForNextSentencePrediction,
|
TFBertMainLayer,
|
||||||
TFBertForSequenceClassification, TFBertForMultipleChoice,
|
TFBertEmbeddings,
|
||||||
TFBertForTokenClassification, TFBertForQuestionAnswering,
|
TFBertModel,
|
||||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TFBertForPreTraining,
|
||||||
|
TFBertForMaskedLM,
|
||||||
|
TFBertForNextSentencePrediction,
|
||||||
|
TFBertForSequenceClassification,
|
||||||
|
TFBertForMultipleChoice,
|
||||||
|
TFBertForTokenClassification,
|
||||||
|
TFBertForQuestionAnswering,
|
||||||
|
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer,
|
from .modeling_tf_gpt2 import (
|
||||||
TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
|
TFGPT2PreTrainedModel,
|
||||||
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TFGPT2MainLayer,
|
||||||
|
TFGPT2Model,
|
||||||
|
TFGPT2LMHeadModel,
|
||||||
|
TFGPT2DoubleHeadsModel,
|
||||||
|
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_tf_openai import (TFOpenAIGPTPreTrainedModel, TFOpenAIGPTMainLayer,
|
from .modeling_tf_openai import (
|
||||||
TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel,
|
TFOpenAIGPTPreTrainedModel,
|
||||||
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TFOpenAIGPTMainLayer,
|
||||||
|
TFOpenAIGPTModel,
|
||||||
|
TFOpenAIGPTLMHeadModel,
|
||||||
|
TFOpenAIGPTDoubleHeadsModel,
|
||||||
|
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_tf_transfo_xl import (TFTransfoXLPreTrainedModel, TFTransfoXLMainLayer,
|
from .modeling_tf_transfo_xl import (
|
||||||
TFTransfoXLModel, TFTransfoXLLMHeadModel,
|
TFTransfoXLPreTrainedModel,
|
||||||
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TFTransfoXLMainLayer,
|
||||||
|
TFTransfoXLModel,
|
||||||
|
TFTransfoXLLMHeadModel,
|
||||||
|
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
|
from .modeling_tf_xlnet import (
|
||||||
TFXLNetModel, TFXLNetLMHeadModel,
|
TFXLNetPreTrainedModel,
|
||||||
|
TFXLNetMainLayer,
|
||||||
|
TFXLNetModel,
|
||||||
|
TFXLNetLMHeadModel,
|
||||||
TFXLNetForSequenceClassification,
|
TFXLNetForSequenceClassification,
|
||||||
TFXLNetForTokenClassification,
|
TFXLNetForTokenClassification,
|
||||||
TFXLNetForQuestionAnsweringSimple,
|
TFXLNetForQuestionAnsweringSimple,
|
||||||
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_tf_xlm import (TFXLMPreTrainedModel, TFXLMMainLayer,
|
from .modeling_tf_xlm import (
|
||||||
TFXLMModel, TFXLMWithLMHeadModel,
|
TFXLMPreTrainedModel,
|
||||||
|
TFXLMMainLayer,
|
||||||
|
TFXLMModel,
|
||||||
|
TFXLMWithLMHeadModel,
|
||||||
TFXLMForSequenceClassification,
|
TFXLMForSequenceClassification,
|
||||||
TFXLMForQuestionAnsweringSimple,
|
TFXLMForQuestionAnsweringSimple,
|
||||||
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer,
|
from .modeling_tf_roberta import (
|
||||||
TFRobertaModel, TFRobertaForMaskedLM,
|
TFRobertaPreTrainedModel,
|
||||||
|
TFRobertaMainLayer,
|
||||||
|
TFRobertaModel,
|
||||||
|
TFRobertaForMaskedLM,
|
||||||
TFRobertaForSequenceClassification,
|
TFRobertaForSequenceClassification,
|
||||||
TFRobertaForTokenClassification,
|
TFRobertaForTokenClassification,
|
||||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
|
from .modeling_tf_distilbert import (
|
||||||
TFDistilBertModel, TFDistilBertForMaskedLM,
|
TFDistilBertPreTrainedModel,
|
||||||
|
TFDistilBertMainLayer,
|
||||||
|
TFDistilBertModel,
|
||||||
|
TFDistilBertForMaskedLM,
|
||||||
TFDistilBertForSequenceClassification,
|
TFDistilBertForSequenceClassification,
|
||||||
TFDistilBertForTokenClassification,
|
TFDistilBertForTokenClassification,
|
||||||
TFDistilBertForQuestionAnswering,
|
TFDistilBertForQuestionAnswering,
|
||||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel,
|
from .modeling_tf_ctrl import (
|
||||||
|
TFCTRLPreTrainedModel,
|
||||||
|
TFCTRLModel,
|
||||||
TFCTRLLMHeadModel,
|
TFCTRLLMHeadModel,
|
||||||
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_tf_albert import (TFAlbertPreTrainedModel, TFAlbertModel, TFAlbertForMaskedLM,
|
from .modeling_tf_albert import (
|
||||||
|
TFAlbertPreTrainedModel,
|
||||||
|
TFAlbertModel,
|
||||||
|
TFAlbertForMaskedLM,
|
||||||
TFAlbertForSequenceClassification,
|
TFAlbertForSequenceClassification,
|
||||||
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_tf_t5 import (TFT5PreTrainedModel, TFT5Model, TFT5WithLMHeadModel,
|
from .modeling_tf_t5 import TFT5PreTrainedModel, TFT5Model, TFT5WithLMHeadModel, TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP)
|
|
||||||
|
|
||||||
# Optimization
|
# Optimization
|
||||||
from .optimization_tf import (WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator)
|
from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator
|
||||||
|
|
||||||
# TF 2.0 <=> PyTorch conversion utilities
|
# TF 2.0 <=> PyTorch conversion utilities
|
||||||
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
|
from .modeling_tf_pytorch_utils import (
|
||||||
|
convert_tf_weight_name_to_pt_weight_name,
|
||||||
load_pytorch_checkpoint_in_tf2_model,
|
load_pytorch_checkpoint_in_tf2_model,
|
||||||
load_pytorch_weights_in_tf2_model,
|
load_pytorch_weights_in_tf2_model,
|
||||||
load_pytorch_model_in_tf2_model,
|
load_pytorch_model_in_tf2_model,
|
||||||
load_tf2_checkpoint_in_pytorch_model,
|
load_tf2_checkpoint_in_pytorch_model,
|
||||||
load_tf2_weights_in_pytorch_model,
|
load_tf2_weights_in_pytorch_model,
|
||||||
load_tf2_model_in_pytorch_model)
|
load_tf2_model_in_pytorch_model,
|
||||||
|
)
|
||||||
|
|
||||||
# Pipelines
|
# Pipelines
|
||||||
from .pipelines import pipeline, PipelineDataFormat, CsvPipelineDataFormat, JsonPipelineDataFormat, PipedPipelineDataFormat, \
|
from .pipelines import (
|
||||||
Pipeline, FeatureExtractionPipeline, QuestionAnsweringPipeline, NerPipeline, TextClassificationPipeline
|
pipeline,
|
||||||
|
PipelineDataFormat,
|
||||||
|
CsvPipelineDataFormat,
|
||||||
|
JsonPipelineDataFormat,
|
||||||
|
PipedPipelineDataFormat,
|
||||||
|
Pipeline,
|
||||||
|
FeatureExtractionPipeline,
|
||||||
|
QuestionAnsweringPipeline,
|
||||||
|
NerPipeline,
|
||||||
|
TextClassificationPipeline,
|
||||||
|
)
|
||||||
|
|
||||||
if not is_tf_available() and not is_torch_available():
|
if not is_tf_available() and not is_torch_available():
|
||||||
logger.warning("Neither PyTorch nor TensorFlow >= 2.0 have been found."
|
logger.warning(
|
||||||
|
"Neither PyTorch nor TensorFlow >= 2.0 have been found."
|
||||||
"Models won't be available and only tokenizers, configuration"
|
"Models won't be available and only tokenizers, configuration"
|
||||||
"and file/data utilities can be used.")
|
"and file/data utilities can be used."
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,16 +1,21 @@
|
|||||||
# coding: utf8
|
# coding: utf8
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
if len(sys.argv) < 2 or sys.argv[1] not in ["convert", "train", "predict", "serve"]:
|
if len(sys.argv) < 2 or sys.argv[1] not in ["convert", "train", "predict", "serve"]:
|
||||||
print(
|
print(
|
||||||
"First argument to `transformers` command line interface should be one of: \n"
|
"First argument to `transformers` command line interface should be one of: \n"
|
||||||
">> convert serve train predict")
|
">> convert serve train predict"
|
||||||
|
)
|
||||||
if sys.argv[1] == "convert":
|
if sys.argv[1] == "convert":
|
||||||
from transformers.commands import convert
|
from transformers.commands import convert
|
||||||
|
|
||||||
convert(sys.argv)
|
convert(sys.argv)
|
||||||
elif sys.argv[1] == "train":
|
elif sys.argv[1] == "train":
|
||||||
from transformers.commands import train
|
from transformers.commands import train
|
||||||
|
|
||||||
train(sys.argv)
|
train(sys.argv)
|
||||||
elif sys.argv[1] == "serve":
|
elif sys.argv[1] == "serve":
|
||||||
pass
|
pass
|
||||||
@@ -19,7 +24,6 @@ def main():
|
|||||||
# parser = ArgumentParser('Transformers CLI tool', usage='transformers serve <command> [<args>]')
|
# parser = ArgumentParser('Transformers CLI tool', usage='transformers serve <command> [<args>]')
|
||||||
# commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
|
# commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
|
||||||
|
|
||||||
|
|
||||||
# # Register commands
|
# # Register commands
|
||||||
# ServeCommand.register_subcommand(commands_parser)
|
# ServeCommand.register_subcommand(commands_parser)
|
||||||
|
|
||||||
@@ -33,5 +37,6 @@ def main():
|
|||||||
# service = args.func(args)
|
# service = args.func(args)
|
||||||
# service.run()
|
# service.run()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
|
||||||
class BaseTransformersCLICommand(ABC):
|
class BaseTransformersCLICommand(ABC):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -11,12 +11,12 @@ def convert_command_factory(args: Namespace):
|
|||||||
Factory function used to convert a model TF 1.0 checkpoint in a PyTorch checkpoint.
|
Factory function used to convert a model TF 1.0 checkpoint in a PyTorch checkpoint.
|
||||||
:return: ServeCommand
|
:return: ServeCommand
|
||||||
"""
|
"""
|
||||||
return ConvertCommand(args.model_type, args.tf_checkpoint, args.pytorch_dump_output,
|
return ConvertCommand(
|
||||||
args.config, args.finetuning_task_name)
|
args.model_type, args.tf_checkpoint, args.pytorch_dump_output, args.config, args.finetuning_task_name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConvertCommand(BaseTransformersCLICommand):
|
class ConvertCommand(BaseTransformersCLICommand):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def register_subcommand(parser: ArgumentParser):
|
def register_subcommand(parser: ArgumentParser):
|
||||||
"""
|
"""
|
||||||
@@ -24,25 +24,39 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
:param parser: Root parser to register command-specific arguments
|
:param parser: Root parser to register command-specific arguments
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
train_parser = parser.add_parser('convert', help="CLI tool to run convert model from original "
|
train_parser = parser.add_parser(
|
||||||
"author checkpoints to Transformesr PyTorch checkpoints.")
|
"convert",
|
||||||
train_parser.add_argument('--model_type', type=str, required=True,
|
help="CLI tool to run convert model from original "
|
||||||
help='Model\'s type.')
|
"author checkpoints to Transformesr PyTorch checkpoints.",
|
||||||
train_parser.add_argument('--tf_checkpoint', type=str, required=True,
|
)
|
||||||
help='TensorFlow checkpoint path or folder.')
|
train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
|
||||||
train_parser.add_argument('--pytorch_dump_output', type=str, required=True,
|
train_parser.add_argument(
|
||||||
help='Path to the PyTorch savd model output.')
|
"--tf_checkpoint", type=str, required=True, help="TensorFlow checkpoint path or folder."
|
||||||
train_parser.add_argument('--config', type=str, default="",
|
)
|
||||||
help='Configuration file path or folder.')
|
train_parser.add_argument(
|
||||||
train_parser.add_argument('--finetuning_task_name', type=str, default=None,
|
"--pytorch_dump_output", type=str, required=True, help="Path to the PyTorch savd model output."
|
||||||
help='Optional fine-tuning task name if the TF model was a finetuned model.')
|
)
|
||||||
|
train_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.")
|
||||||
|
train_parser.add_argument(
|
||||||
|
"--finetuning_task_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Optional fine-tuning task name if the TF model was a finetuned model.",
|
||||||
|
)
|
||||||
train_parser.set_defaults(func=convert_command_factory)
|
train_parser.set_defaults(func=convert_command_factory)
|
||||||
|
|
||||||
def __init__(self, model_type: str, tf_checkpoint: str, pytorch_dump_output: str,
|
def __init__(
|
||||||
config: str, finetuning_task_name: str, *args):
|
self,
|
||||||
self._logger = getLogger('transformers-cli/converting')
|
model_type: str,
|
||||||
|
tf_checkpoint: str,
|
||||||
|
pytorch_dump_output: str,
|
||||||
|
config: str,
|
||||||
|
finetuning_task_name: str,
|
||||||
|
*args
|
||||||
|
):
|
||||||
|
self._logger = getLogger("transformers-cli/converting")
|
||||||
|
|
||||||
self._logger.info('Loading model {}'.format(model_type))
|
self._logger.info("Loading model {}".format(model_type))
|
||||||
self._model_type = model_type
|
self._model_type = model_type
|
||||||
self._tf_checkpoint = tf_checkpoint
|
self._tf_checkpoint = tf_checkpoint
|
||||||
self._pytorch_dump_output = pytorch_dump_output
|
self._pytorch_dump_output = pytorch_dump_output
|
||||||
@@ -52,63 +66,80 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
def run(self):
|
def run(self):
|
||||||
if self._model_type == "bert":
|
if self._model_type == "bert":
|
||||||
try:
|
try:
|
||||||
from transformers.convert_bert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
|
from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
|
||||||
|
convert_tf_checkpoint_to_pytorch,
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \
|
msg = (
|
||||||
"In that case, it requires TensorFlow to be installed. Please see " \
|
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||||
|
"In that case, it requires TensorFlow to be installed. Please see "
|
||||||
"https://www.tensorflow.org/install/ for installation instructions."
|
"https://www.tensorflow.org/install/ for installation instructions."
|
||||||
|
)
|
||||||
raise ImportError(msg)
|
raise ImportError(msg)
|
||||||
|
|
||||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||||
elif self._model_type == "gpt":
|
elif self._model_type == "gpt":
|
||||||
from transformers.convert_openai_original_tf_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
|
from transformers.convert_openai_original_tf_checkpoint_to_pytorch import (
|
||||||
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint,
|
convert_openai_checkpoint_to_pytorch,
|
||||||
self._config,
|
)
|
||||||
self._pytorch_dump_output)
|
|
||||||
|
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||||
elif self._model_type == "transfo_xl":
|
elif self._model_type == "transfo_xl":
|
||||||
try:
|
try:
|
||||||
from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
|
from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
|
||||||
|
convert_transfo_xl_checkpoint_to_pytorch,
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \
|
msg = (
|
||||||
"In that case, it requires TensorFlow to be installed. Please see " \
|
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||||
|
"In that case, it requires TensorFlow to be installed. Please see "
|
||||||
"https://www.tensorflow.org/install/ for installation instructions."
|
"https://www.tensorflow.org/install/ for installation instructions."
|
||||||
|
)
|
||||||
raise ImportError(msg)
|
raise ImportError(msg)
|
||||||
|
|
||||||
if 'ckpt' in self._tf_checkpoint.lower():
|
if "ckpt" in self._tf_checkpoint.lower():
|
||||||
TF_CHECKPOINT = self._tf_checkpoint
|
TF_CHECKPOINT = self._tf_checkpoint
|
||||||
TF_DATASET_FILE = ""
|
TF_DATASET_FILE = ""
|
||||||
else:
|
else:
|
||||||
TF_DATASET_FILE = self._tf_checkpoint
|
TF_DATASET_FILE = self._tf_checkpoint
|
||||||
TF_CHECKPOINT = ""
|
TF_CHECKPOINT = ""
|
||||||
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT,
|
convert_transfo_xl_checkpoint_to_pytorch(
|
||||||
self._config,
|
TF_CHECKPOINT, self._config, self._pytorch_dump_output, TF_DATASET_FILE
|
||||||
self._pytorch_dump_output,
|
)
|
||||||
TF_DATASET_FILE)
|
|
||||||
elif self._model_type == "gpt2":
|
elif self._model_type == "gpt2":
|
||||||
try:
|
try:
|
||||||
from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
|
from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import (
|
||||||
|
convert_gpt2_checkpoint_to_pytorch,
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \
|
msg = (
|
||||||
"In that case, it requires TensorFlow to be installed. Please see " \
|
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||||
|
"In that case, it requires TensorFlow to be installed. Please see "
|
||||||
"https://www.tensorflow.org/install/ for installation instructions."
|
"https://www.tensorflow.org/install/ for installation instructions."
|
||||||
|
)
|
||||||
raise ImportError(msg)
|
raise ImportError(msg)
|
||||||
|
|
||||||
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||||
elif self._model_type == "xlnet":
|
elif self._model_type == "xlnet":
|
||||||
try:
|
try:
|
||||||
from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
|
from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import (
|
||||||
|
convert_xlnet_checkpoint_to_pytorch,
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \
|
msg = (
|
||||||
"In that case, it requires TensorFlow to be installed. Please see " \
|
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||||
|
"In that case, it requires TensorFlow to be installed. Please see "
|
||||||
"https://www.tensorflow.org/install/ for installation instructions."
|
"https://www.tensorflow.org/install/ for installation instructions."
|
||||||
|
)
|
||||||
raise ImportError(msg)
|
raise ImportError(msg)
|
||||||
|
|
||||||
convert_xlnet_checkpoint_to_pytorch(self._tf_checkpoint,
|
convert_xlnet_checkpoint_to_pytorch(
|
||||||
self._config,
|
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
|
||||||
self._pytorch_dump_output,
|
)
|
||||||
self._finetuning_task_name)
|
|
||||||
elif self._model_type == "xlm":
|
elif self._model_type == "xlm":
|
||||||
from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
|
from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
|
||||||
|
convert_xlm_checkpoint_to_pytorch,
|
||||||
|
)
|
||||||
|
|
||||||
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
|
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -8,13 +8,16 @@ def download_command_factory(args):
|
|||||||
|
|
||||||
|
|
||||||
class DownloadCommand(BaseTransformersCLICommand):
|
class DownloadCommand(BaseTransformersCLICommand):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def register_subcommand(parser: ArgumentParser):
|
def register_subcommand(parser: ArgumentParser):
|
||||||
download_parser = parser.add_parser('download')
|
download_parser = parser.add_parser("download")
|
||||||
download_parser.add_argument('--cache-dir', type=str, default=None, help='Path to location to store the models')
|
download_parser.add_argument(
|
||||||
download_parser.add_argument('--force', action='store_true', help='Force the model to be download even if already in cache-dir')
|
"--cache-dir", type=str, default=None, help="Path to location to store the models"
|
||||||
download_parser.add_argument('model', type=str, help='Name of the model to download')
|
)
|
||||||
|
download_parser.add_argument(
|
||||||
|
"--force", action="store_true", help="Force the model to be download even if already in cache-dir"
|
||||||
|
)
|
||||||
|
download_parser.add_argument("model", type=str, help="Name of the model to download")
|
||||||
download_parser.set_defaults(func=download_command_factory)
|
download_parser.set_defaults(func=download_command_factory)
|
||||||
|
|
||||||
def __init__(self, model: str, cache: str, force: bool):
|
def __init__(self, model: str, cache: str, force: bool):
|
||||||
|
|||||||
@@ -10,52 +10,72 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
|||||||
|
|
||||||
def try_infer_format_from_ext(path: str):
|
def try_infer_format_from_ext(path: str):
|
||||||
if not path:
|
if not path:
|
||||||
return 'pipe'
|
return "pipe"
|
||||||
|
|
||||||
for ext in PipelineDataFormat.SUPPORTED_FORMATS:
|
for ext in PipelineDataFormat.SUPPORTED_FORMATS:
|
||||||
if path.endswith(ext):
|
if path.endswith(ext):
|
||||||
return ext
|
return ext
|
||||||
|
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'Unable to determine file format from file extension {}. '
|
"Unable to determine file format from file extension {}. "
|
||||||
'Please provide the format through --format {}'.format(path, PipelineDataFormat.SUPPORTED_FORMATS)
|
"Please provide the format through --format {}".format(path, PipelineDataFormat.SUPPORTED_FORMATS)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_command_factory(args):
|
def run_command_factory(args):
|
||||||
nlp = pipeline(task=args.task,
|
nlp = pipeline(
|
||||||
|
task=args.task,
|
||||||
model=args.model if args.model else None,
|
model=args.model if args.model else None,
|
||||||
config=args.config,
|
config=args.config,
|
||||||
tokenizer=args.tokenizer,
|
tokenizer=args.tokenizer,
|
||||||
device=args.device)
|
device=args.device,
|
||||||
format = try_infer_format_from_ext(args.input) if args.format == 'infer' else args.format
|
)
|
||||||
reader = PipelineDataFormat.from_str(format=format,
|
format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format
|
||||||
|
reader = PipelineDataFormat.from_str(
|
||||||
|
format=format,
|
||||||
output_path=args.output,
|
output_path=args.output,
|
||||||
input_path=args.input,
|
input_path=args.input,
|
||||||
column=args.column if args.column else nlp.default_input_names,
|
column=args.column if args.column else nlp.default_input_names,
|
||||||
overwrite=args.overwrite)
|
overwrite=args.overwrite,
|
||||||
|
)
|
||||||
return RunCommand(nlp, reader)
|
return RunCommand(nlp, reader)
|
||||||
|
|
||||||
|
|
||||||
class RunCommand(BaseTransformersCLICommand):
|
class RunCommand(BaseTransformersCLICommand):
|
||||||
|
|
||||||
def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
|
def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
|
||||||
self._nlp = nlp
|
self._nlp = nlp
|
||||||
self._reader = reader
|
self._reader = reader
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def register_subcommand(parser: ArgumentParser):
|
def register_subcommand(parser: ArgumentParser):
|
||||||
run_parser = parser.add_parser('run', help="Run a pipeline through the CLI")
|
run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
|
||||||
run_parser.add_argument('--task', choices=SUPPORTED_TASKS.keys(), help='Task to run')
|
run_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run")
|
||||||
run_parser.add_argument('--input', type=str, help='Path to the file to use for inference')
|
run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
|
||||||
run_parser.add_argument('--output', type=str, help='Path to the file that will be used post to write results.')
|
run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
|
||||||
run_parser.add_argument('--model', type=str, help='Name or path to the model to instantiate.')
|
run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
|
||||||
run_parser.add_argument('--config', type=str, help='Name or path to the model\'s config to instantiate.')
|
run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
|
||||||
run_parser.add_argument('--tokenizer', type=str, help='Name of the tokenizer to use. (default: same as the model name)')
|
run_parser.add_argument(
|
||||||
run_parser.add_argument('--column', type=str, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)')
|
"--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
|
||||||
run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
|
)
|
||||||
run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
|
run_parser.add_argument(
|
||||||
run_parser.add_argument('--overwrite', action='store_true', help='Allow overwriting the output file.')
|
"--column",
|
||||||
|
type=str,
|
||||||
|
help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
|
||||||
|
)
|
||||||
|
run_parser.add_argument(
|
||||||
|
"--format",
|
||||||
|
type=str,
|
||||||
|
default="infer",
|
||||||
|
choices=PipelineDataFormat.SUPPORTED_FORMATS,
|
||||||
|
help="Input format to read from",
|
||||||
|
)
|
||||||
|
run_parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=int,
|
||||||
|
default=-1,
|
||||||
|
help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
|
||||||
|
)
|
||||||
|
run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
|
||||||
run_parser.set_defaults(func=run_command_factory)
|
run_parser.set_defaults(func=run_command_factory)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
@@ -71,9 +91,6 @@ class RunCommand(BaseTransformersCLICommand):
|
|||||||
# Saving data
|
# Saving data
|
||||||
if self._nlp.binary_output:
|
if self._nlp.binary_output:
|
||||||
binary_path = self._reader.save_binary(outputs)
|
binary_path = self._reader.save_binary(outputs)
|
||||||
logger.warning('Current pipeline requires output to be in binary format, saving at {}'.format(binary_path))
|
logger.warning("Current pipeline requires output to be in binary format, saving at {}".format(binary_path))
|
||||||
else:
|
else:
|
||||||
self._reader.save(outputs)
|
self._reader.save(outputs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ try:
|
|||||||
from uvicorn import run
|
from uvicorn import run
|
||||||
from fastapi import FastAPI, HTTPException, Body
|
from fastapi import FastAPI, HTTPException, Body
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
_serve_dependancies_installed = True
|
_serve_dependancies_installed = True
|
||||||
except (ImportError, AttributeError):
|
except (ImportError, AttributeError):
|
||||||
BaseModel = object
|
BaseModel = object
|
||||||
@@ -17,18 +18,21 @@ from transformers import Pipeline
|
|||||||
from transformers.commands import BaseTransformersCLICommand
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||||
|
|
||||||
logger = logging.getLogger('transformers-cli/serving')
|
logger = logging.getLogger("transformers-cli/serving")
|
||||||
|
|
||||||
|
|
||||||
def serve_command_factory(args: Namespace):
|
def serve_command_factory(args: Namespace):
|
||||||
"""
|
"""
|
||||||
Factory function used to instantiate serving server from provided command line arguments.
|
Factory function used to instantiate serving server from provided command line arguments.
|
||||||
:return: ServeCommand
|
:return: ServeCommand
|
||||||
"""
|
"""
|
||||||
nlp = pipeline(task=args.task,
|
nlp = pipeline(
|
||||||
|
task=args.task,
|
||||||
model=args.model if args.model else None,
|
model=args.model if args.model else None,
|
||||||
config=args.config,
|
config=args.config,
|
||||||
tokenizer=args.tokenizer,
|
tokenizer=args.tokenizer,
|
||||||
device=args.device)
|
device=args.device,
|
||||||
|
)
|
||||||
return ServeCommand(nlp, args.host, args.port)
|
return ServeCommand(nlp, args.host, args.port)
|
||||||
|
|
||||||
|
|
||||||
@@ -36,6 +40,7 @@ class ServeModelInfoResult(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Expose model information
|
Expose model information
|
||||||
"""
|
"""
|
||||||
|
|
||||||
infos: dict
|
infos: dict
|
||||||
|
|
||||||
|
|
||||||
@@ -43,6 +48,7 @@ class ServeTokenizeResult(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Tokenize result model
|
Tokenize result model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tokens: List[str]
|
tokens: List[str]
|
||||||
tokens_ids: Optional[List[int]]
|
tokens_ids: Optional[List[int]]
|
||||||
|
|
||||||
@@ -51,6 +57,7 @@ class ServeDeTokenizeResult(BaseModel):
|
|||||||
"""
|
"""
|
||||||
DeTokenize result model
|
DeTokenize result model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
|
|
||||||
@@ -58,11 +65,11 @@ class ServeForwardResult(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Forward result model
|
Forward result model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output: Any
|
output: Any
|
||||||
|
|
||||||
|
|
||||||
class ServeCommand(BaseTransformersCLICommand):
|
class ServeCommand(BaseTransformersCLICommand):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def register_subcommand(parser: ArgumentParser):
|
def register_subcommand(parser: ArgumentParser):
|
||||||
"""
|
"""
|
||||||
@@ -70,14 +77,23 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
:param parser: Root parser to register command-specific arguments
|
:param parser: Root parser to register command-specific arguments
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
serve_parser = parser.add_parser('serve', help='CLI tool to run inference requests through REST and GraphQL endpoints.')
|
serve_parser = parser.add_parser(
|
||||||
serve_parser.add_argument('--task', type=str, choices=SUPPORTED_TASKS.keys(), help='The task to run the pipeline on')
|
"serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
|
||||||
serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
|
)
|
||||||
serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.')
|
serve_parser.add_argument(
|
||||||
serve_parser.add_argument('--model', type=str, help='Model\'s name or path to stored model.')
|
"--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on"
|
||||||
serve_parser.add_argument('--config', type=str, help='Model\'s config name or path to stored model.')
|
)
|
||||||
serve_parser.add_argument('--tokenizer', type=str, help='Tokenizer name to use.')
|
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
|
||||||
serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
|
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
|
||||||
|
serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
|
||||||
|
serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
|
||||||
|
serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
|
||||||
|
serve_parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=int,
|
||||||
|
default=-1,
|
||||||
|
help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
|
||||||
|
)
|
||||||
serve_parser.set_defaults(func=serve_command_factory)
|
serve_parser.set_defaults(func=serve_command_factory)
|
||||||
|
|
||||||
def __init__(self, pipeline: Pipeline, host: str, port: int):
|
def __init__(self, pipeline: Pipeline, host: str, port: int):
|
||||||
@@ -87,18 +103,22 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
self._host = host
|
self._host = host
|
||||||
self._port = port
|
self._port = port
|
||||||
if not _serve_dependancies_installed:
|
if not _serve_dependancies_installed:
|
||||||
raise ImportError("Using serve command requires FastAPI and unicorn. "
|
raise ImportError(
|
||||||
|
"Using serve command requires FastAPI and unicorn. "
|
||||||
"Please install transformers with [serving]: pip install transformers[serving]."
|
"Please install transformers with [serving]: pip install transformers[serving]."
|
||||||
"Or install FastAPI and unicorn separatly.")
|
"Or install FastAPI and unicorn separatly."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info('Serving model over {}:{}'.format(host, port))
|
logger.info("Serving model over {}:{}".format(host, port))
|
||||||
self._app = FastAPI()
|
self._app = FastAPI()
|
||||||
|
|
||||||
# Register routes
|
# Register routes
|
||||||
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
|
self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
|
||||||
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
|
self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
|
||||||
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
|
self._app.add_api_route(
|
||||||
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
|
"/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
|
||||||
|
)
|
||||||
|
self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
run(self._app, host=self._host, port=self._port)
|
run(self._app, host=self._host, port=self._port)
|
||||||
@@ -122,11 +142,14 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
return ServeTokenizeResult(tokens=tokens_txt)
|
return ServeTokenizeResult(tokens=tokens_txt)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
|
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
||||||
|
|
||||||
def detokenize(self, tokens_ids: List[int] = Body(None, embed=True),
|
def detokenize(
|
||||||
|
self,
|
||||||
|
tokens_ids: List[int] = Body(None, embed=True),
|
||||||
skip_special_tokens: bool = Body(False, embed=True),
|
skip_special_tokens: bool = Body(False, embed=True),
|
||||||
cleanup_tokenization_spaces: bool = Body(True, embed=True)):
|
cleanup_tokenization_spaces: bool = Body(True, embed=True),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Detokenize the provided tokens ids to readable text:
|
Detokenize the provided tokens ids to readable text:
|
||||||
- **tokens_ids**: List of tokens ids
|
- **tokens_ids**: List of tokens ids
|
||||||
@@ -135,9 +158,9 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
|
decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
|
||||||
return ServeDeTokenizeResult(model='', text=decoded_str)
|
return ServeDeTokenizeResult(model="", text=decoded_str)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
|
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
||||||
|
|
||||||
def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
|
def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -3,9 +3,12 @@ from argparse import ArgumentParser, Namespace
|
|||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
from transformers import (is_tf_available, is_torch_available,
|
from transformers import (
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
TextClassificationPipeline,
|
TextClassificationPipeline,
|
||||||
SingleSentenceClassificationProcessor as Processor)
|
SingleSentenceClassificationProcessor as Processor,
|
||||||
|
)
|
||||||
|
|
||||||
if not is_tf_available() and not is_torch_available():
|
if not is_tf_available() and not is_torch_available():
|
||||||
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
||||||
@@ -14,6 +17,7 @@ if not is_tf_available() and not is_torch_available():
|
|||||||
USE_XLA = False
|
USE_XLA = False
|
||||||
USE_AMP = False
|
USE_AMP = False
|
||||||
|
|
||||||
|
|
||||||
def train_command_factory(args: Namespace):
|
def train_command_factory(args: Namespace):
|
||||||
"""
|
"""
|
||||||
Factory function used to instantiate serving server from provided command line arguments.
|
Factory function used to instantiate serving server from provided command line arguments.
|
||||||
@@ -23,7 +27,6 @@ def train_command_factory(args: Namespace):
|
|||||||
|
|
||||||
|
|
||||||
class TrainCommand(BaseTransformersCLICommand):
|
class TrainCommand(BaseTransformersCLICommand):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def register_subcommand(parser: ArgumentParser):
|
def register_subcommand(parser: ArgumentParser):
|
||||||
"""
|
"""
|
||||||
@@ -31,47 +34,54 @@ class TrainCommand(BaseTransformersCLICommand):
|
|||||||
:param parser: Root parser to register command-specific arguments
|
:param parser: Root parser to register command-specific arguments
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
train_parser = parser.add_parser('train', help='CLI tool to train a model on a task.')
|
train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")
|
||||||
|
|
||||||
train_parser.add_argument('--train_data', type=str, required=True,
|
train_parser.add_argument(
|
||||||
|
"--train_data",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
help="path to train (and optionally evaluation) dataset as a csv with "
|
help="path to train (and optionally evaluation) dataset as a csv with "
|
||||||
"tab separated labels and sentences.")
|
"tab separated labels and sentences.",
|
||||||
train_parser.add_argument('--column_label', type=int, default=0,
|
)
|
||||||
help='Column of the dataset csv file with example labels.')
|
train_parser.add_argument(
|
||||||
train_parser.add_argument('--column_text', type=int, default=1,
|
"--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
|
||||||
help='Column of the dataset csv file with example texts.')
|
)
|
||||||
train_parser.add_argument('--column_id', type=int, default=2,
|
train_parser.add_argument(
|
||||||
help='Column of the dataset csv file with example ids.')
|
"--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
|
||||||
train_parser.add_argument('--skip_first_row', action='store_true',
|
)
|
||||||
help='Skip the first row of the csv file (headers).')
|
train_parser.add_argument(
|
||||||
|
"--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
|
||||||
|
)
|
||||||
|
train_parser.add_argument(
|
||||||
|
"--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
|
||||||
|
)
|
||||||
|
|
||||||
train_parser.add_argument('--validation_data', type=str, default='',
|
train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
|
||||||
help='path to validation dataset.')
|
train_parser.add_argument(
|
||||||
train_parser.add_argument('--validation_split', type=float, default=0.1,
|
"--validation_split",
|
||||||
help="if validation dataset is not provided, fraction of train dataset "
|
type=float,
|
||||||
"to use as validation dataset.")
|
default=0.1,
|
||||||
|
help="if validation dataset is not provided, fraction of train dataset " "to use as validation dataset.",
|
||||||
|
)
|
||||||
|
|
||||||
train_parser.add_argument('--output', type=str, default='./',
|
train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")
|
||||||
help='path to saved the trained model.')
|
|
||||||
|
|
||||||
train_parser.add_argument('--task', type=str, default='text_classification',
|
train_parser.add_argument(
|
||||||
help='Task to train the model on.')
|
"--task", type=str, default="text_classification", help="Task to train the model on."
|
||||||
train_parser.add_argument('--model', type=str, default='bert-base-uncased',
|
)
|
||||||
help='Model\'s name or path to stored model.')
|
train_parser.add_argument(
|
||||||
train_parser.add_argument('--train_batch_size', type=int, default=32,
|
"--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model."
|
||||||
help='Batch size for training.')
|
)
|
||||||
train_parser.add_argument('--valid_batch_size', type=int, default=64,
|
train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
|
||||||
help='Batch size for validation.')
|
train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
|
||||||
train_parser.add_argument('--learning_rate', type=float, default=3e-5,
|
train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
|
||||||
help="Learning rate.")
|
train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
|
||||||
train_parser.add_argument('--adam_epsilon', type=float, default=1e-08,
|
|
||||||
help="Epsilon for Adam optimizer.")
|
|
||||||
train_parser.set_defaults(func=train_command_factory)
|
train_parser.set_defaults(func=train_command_factory)
|
||||||
|
|
||||||
def __init__(self, args: Namespace):
|
def __init__(self, args: Namespace):
|
||||||
self.logger = getLogger('transformers-cli/training')
|
self.logger = getLogger("transformers-cli/training")
|
||||||
|
|
||||||
self.framework = 'tf' if is_tf_available() else 'torch'
|
self.framework = "tf" if is_tf_available() else "torch"
|
||||||
|
|
||||||
os.makedirs(args.output, exist_ok=True)
|
os.makedirs(args.output, exist_ok=True)
|
||||||
assert os.path.isdir(args.output)
|
assert os.path.isdir(args.output)
|
||||||
@@ -81,28 +91,32 @@ class TrainCommand(BaseTransformersCLICommand):
|
|||||||
self.column_text = args.column_text
|
self.column_text = args.column_text
|
||||||
self.column_id = args.column_id
|
self.column_id = args.column_id
|
||||||
|
|
||||||
self.logger.info('Loading {} pipeline for {}'.format(args.task, args.model))
|
self.logger.info("Loading {} pipeline for {}".format(args.task, args.model))
|
||||||
if args.task == 'text_classification':
|
if args.task == "text_classification":
|
||||||
self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
|
self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
|
||||||
elif args.task == 'token_classification':
|
elif args.task == "token_classification":
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
elif args.task == 'question_answering':
|
elif args.task == "question_answering":
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
self.logger.info('Loading dataset from {}'.format(args.train_data))
|
self.logger.info("Loading dataset from {}".format(args.train_data))
|
||||||
self.train_dataset = Processor.create_from_csv(args.train_data,
|
self.train_dataset = Processor.create_from_csv(
|
||||||
|
args.train_data,
|
||||||
column_label=args.column_label,
|
column_label=args.column_label,
|
||||||
column_text=args.column_text,
|
column_text=args.column_text,
|
||||||
column_id=args.column_id,
|
column_id=args.column_id,
|
||||||
skip_first_row=args.skip_first_row)
|
skip_first_row=args.skip_first_row,
|
||||||
|
)
|
||||||
self.valid_dataset = None
|
self.valid_dataset = None
|
||||||
if args.validation_data:
|
if args.validation_data:
|
||||||
self.logger.info('Loading validation dataset from {}'.format(args.validation_data))
|
self.logger.info("Loading validation dataset from {}".format(args.validation_data))
|
||||||
self.valid_dataset = Processor.create_from_csv(args.validation_data,
|
self.valid_dataset = Processor.create_from_csv(
|
||||||
|
args.validation_data,
|
||||||
column_label=args.column_label,
|
column_label=args.column_label,
|
||||||
column_text=args.column_text,
|
column_text=args.column_text,
|
||||||
column_id=args.column_id,
|
column_id=args.column_id,
|
||||||
skip_first_row=args.skip_first_row)
|
skip_first_row=args.skip_first_row,
|
||||||
|
)
|
||||||
|
|
||||||
self.validation_split = args.validation_split
|
self.validation_split = args.validation_split
|
||||||
self.train_batch_size = args.train_batch_size
|
self.train_batch_size = args.train_batch_size
|
||||||
@@ -111,7 +125,7 @@ class TrainCommand(BaseTransformersCLICommand):
|
|||||||
self.adam_epsilon = args.adam_epsilon
|
self.adam_epsilon = args.adam_epsilon
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
if self.framework == 'tf':
|
if self.framework == "tf":
|
||||||
return self.run_tf()
|
return self.run_tf()
|
||||||
return self.run_torch()
|
return self.run_torch()
|
||||||
|
|
||||||
@@ -119,13 +133,15 @@ class TrainCommand(BaseTransformersCLICommand):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def run_tf(self):
|
def run_tf(self):
|
||||||
self.pipeline.fit(self.train_dataset,
|
self.pipeline.fit(
|
||||||
|
self.train_dataset,
|
||||||
validation_data=self.valid_dataset,
|
validation_data=self.valid_dataset,
|
||||||
validation_split=self.validation_split,
|
validation_split=self.validation_split,
|
||||||
learning_rate=self.learning_rate,
|
learning_rate=self.learning_rate,
|
||||||
adam_epsilon=self.adam_epsilon,
|
adam_epsilon=self.adam_epsilon,
|
||||||
train_batch_size=self.train_batch_size,
|
train_batch_size=self.train_batch_size,
|
||||||
valid_batch_size=self.valid_batch_size)
|
valid_batch_size=self.valid_batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
# Save trained pipeline
|
# Save trained pipeline
|
||||||
self.pipeline.save_pretrained(self.output)
|
self.pipeline.save_pretrained(self.output)
|
||||||
|
|||||||
@@ -9,28 +9,31 @@ from transformers.hf_api import HfApi, HfFolder, HTTPError
|
|||||||
class UserCommands(BaseTransformersCLICommand):
|
class UserCommands(BaseTransformersCLICommand):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def register_subcommand(parser: ArgumentParser):
|
def register_subcommand(parser: ArgumentParser):
|
||||||
login_parser = parser.add_parser('login')
|
login_parser = parser.add_parser("login")
|
||||||
login_parser.set_defaults(func=lambda args: LoginCommand(args))
|
login_parser.set_defaults(func=lambda args: LoginCommand(args))
|
||||||
whoami_parser = parser.add_parser('whoami')
|
whoami_parser = parser.add_parser("whoami")
|
||||||
whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
|
whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
|
||||||
logout_parser = parser.add_parser('logout')
|
logout_parser = parser.add_parser("logout")
|
||||||
logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
|
logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
|
||||||
list_parser = parser.add_parser('ls')
|
list_parser = parser.add_parser("ls")
|
||||||
list_parser.set_defaults(func=lambda args: ListObjsCommand(args))
|
list_parser.set_defaults(func=lambda args: ListObjsCommand(args))
|
||||||
# upload
|
# upload
|
||||||
upload_parser = parser.add_parser('upload')
|
upload_parser = parser.add_parser("upload")
|
||||||
upload_parser.add_argument('path', type=str, help='Local path of the folder or individual file to upload.')
|
upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
|
||||||
upload_parser.add_argument('--filename', type=str, default=None, help='Optional: override individual object filename on S3.')
|
upload_parser.add_argument(
|
||||||
|
"--filename", type=str, default=None, help="Optional: override individual object filename on S3."
|
||||||
|
)
|
||||||
upload_parser.set_defaults(func=lambda args: UploadCommand(args))
|
upload_parser.set_defaults(func=lambda args: UploadCommand(args))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ANSI:
|
class ANSI:
|
||||||
"""
|
"""
|
||||||
Helper for en.wikipedia.org/wiki/ANSI_escape_code
|
Helper for en.wikipedia.org/wiki/ANSI_escape_code
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_bold = u"\u001b[1m"
|
_bold = u"\u001b[1m"
|
||||||
_reset = u"\u001b[0m"
|
_reset = u"\u001b[0m"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def bold(cls, s):
|
def bold(cls, s):
|
||||||
return "{}{}{}".format(cls._bold, s, cls._reset)
|
return "{}{}{}".format(cls._bold, s, cls._reset)
|
||||||
@@ -44,14 +47,16 @@ class BaseUserCommand:
|
|||||||
|
|
||||||
class LoginCommand(BaseUserCommand):
|
class LoginCommand(BaseUserCommand):
|
||||||
def run(self):
|
def run(self):
|
||||||
print("""
|
print(
|
||||||
|
"""
|
||||||
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
|
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
|
||||||
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
|
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
|
||||||
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
|
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
|
||||||
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
|
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
|
||||||
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
|
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
|
||||||
|
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
username = input("Username: ")
|
username = input("Username: ")
|
||||||
password = getpass()
|
password = getpass()
|
||||||
try:
|
try:
|
||||||
@@ -101,16 +106,10 @@ class ListObjsCommand(BaseUserCommand):
|
|||||||
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
|
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
|
||||||
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
|
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
|
||||||
lines = []
|
lines = []
|
||||||
lines.append(
|
lines.append(row_format.format(*headers))
|
||||||
row_format.format(*headers)
|
lines.append(row_format.format(*["-" * w for w in col_widths]))
|
||||||
)
|
|
||||||
lines.append(
|
|
||||||
row_format.format(*["-" * w for w in col_widths])
|
|
||||||
)
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
lines.append(
|
lines.append(row_format.format(*row))
|
||||||
row_format.format(*row)
|
|
||||||
)
|
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
@@ -126,15 +125,8 @@ class ListObjsCommand(BaseUserCommand):
|
|||||||
if len(objs) == 0:
|
if len(objs) == 0:
|
||||||
print("No shared file yet")
|
print("No shared file yet")
|
||||||
exit()
|
exit()
|
||||||
rows = [ [
|
rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
|
||||||
obj.filename,
|
print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
|
||||||
obj.LastModified,
|
|
||||||
obj.ETag,
|
|
||||||
obj.Size
|
|
||||||
] for obj in objs ]
|
|
||||||
print(
|
|
||||||
self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"])
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UploadCommand(BaseUserCommand):
|
class UploadCommand(BaseUserCommand):
|
||||||
@@ -143,13 +135,7 @@ class UploadCommand(BaseUserCommand):
|
|||||||
Recursively list all files in a folder.
|
Recursively list all files in a folder.
|
||||||
"""
|
"""
|
||||||
entries: List[os.DirEntry] = list(os.scandir(rel_path))
|
entries: List[os.DirEntry] = list(os.scandir(rel_path))
|
||||||
files = [
|
files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # filepath # filename
|
||||||
(
|
|
||||||
os.path.join(os.getcwd(), f.path), # filepath
|
|
||||||
f.path # filename
|
|
||||||
)
|
|
||||||
for f in entries if f.is_file()
|
|
||||||
]
|
|
||||||
for f in entries:
|
for f in entries:
|
||||||
if f.is_dir():
|
if f.is_dir():
|
||||||
files += self.walk_dir(f.path)
|
files += self.walk_dir(f.path)
|
||||||
@@ -173,22 +159,14 @@ class UploadCommand(BaseUserCommand):
|
|||||||
raise ValueError("Not a valid file or directory: {}".format(local_path))
|
raise ValueError("Not a valid file or directory: {}".format(local_path))
|
||||||
|
|
||||||
for filepath, filename in files:
|
for filepath, filename in files:
|
||||||
print(
|
print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))
|
||||||
"About to upload file {} to S3 under filename {}".format(
|
|
||||||
ANSI.bold(filepath), ANSI.bold(filename)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
choice = input("Proceed? [Y/n] ").lower()
|
choice = input("Proceed? [Y/n] ").lower()
|
||||||
if not (choice == "" or choice == "y" or choice == "yes"):
|
if not (choice == "" or choice == "y" or choice == "yes"):
|
||||||
print("Abort")
|
print("Abort")
|
||||||
exit()
|
exit()
|
||||||
print(
|
print(ANSI.bold("Uploading... This might take a while if files are large"))
|
||||||
ANSI.bold("Uploading... This might take a while if files are large")
|
|
||||||
)
|
|
||||||
for filepath, filename in files:
|
for filepath, filename in files:
|
||||||
access_url = self._api.presign_and_upload(
|
access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath)
|
||||||
token=token, filename=filename, filepath=filepath
|
|
||||||
)
|
|
||||||
print("Your file now lives at:")
|
print("Your file now lives at:")
|
||||||
print(access_url)
|
print(access_url)
|
||||||
|
|||||||
@@ -18,16 +18,17 @@
|
|||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
|
"albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
|
||||||
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
|
"albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
|
||||||
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
|
"albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
|
||||||
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
|
"albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
|
||||||
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
|
"albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
|
||||||
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
|
"albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
|
||||||
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
|
"albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
|
||||||
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
|
"albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class AlbertConfig(PretrainedConfig):
|
class AlbertConfig(PretrainedConfig):
|
||||||
"""Configuration for `AlbertModel`.
|
"""Configuration for `AlbertModel`.
|
||||||
|
|
||||||
@@ -36,7 +37,8 @@ class AlbertConfig(PretrainedConfig):
|
|||||||
|
|
||||||
pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
vocab_size=30000,
|
vocab_size=30000,
|
||||||
embedding_size=128,
|
embedding_size=128,
|
||||||
hidden_size=4096,
|
hidden_size=4096,
|
||||||
@@ -51,7 +53,9 @@ class AlbertConfig(PretrainedConfig):
|
|||||||
max_position_embeddings=512,
|
max_position_embeddings=512,
|
||||||
type_vocab_size=2,
|
type_vocab_size=2,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
layer_norm_eps=1e-12, **kwargs):
|
layer_norm_eps=1e-12,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
"""Constructs AlbertConfig.
|
"""Constructs AlbertConfig.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ from .configuration_xlm_roberta import XLMRobertaConfig, XLM_ROBERTA_PRETRAINED_
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value)
|
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||||
|
(key, value)
|
||||||
for pretrained_map in [
|
for pretrained_map in [
|
||||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
@@ -51,7 +52,8 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value)
|
|||||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
]
|
]
|
||||||
for key, value, in pretrained_map.items())
|
for key, value, in pretrained_map.items()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoConfig(object):
|
class AutoConfig(object):
|
||||||
@@ -79,37 +81,42 @@ class AutoConfig(object):
|
|||||||
- contains `ctrl` : CTRLConfig (CTRL model)
|
- contains `ctrl` : CTRLConfig (CTRL model)
|
||||||
This class cannot be instantiated using `__init__()` (throw an error).
|
This class cannot be instantiated using `__init__()` (throw an error).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
raise EnvironmentError("AutoConfig is designed to be instantiated "
|
raise EnvironmentError(
|
||||||
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method.")
|
"AutoConfig is designed to be instantiated "
|
||||||
|
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_model(cls, model_type, *args, **kwargs):
|
def for_model(cls, model_type, *args, **kwargs):
|
||||||
if 'distilbert' in model_type:
|
if "distilbert" in model_type:
|
||||||
return DistilBertConfig(*args, **kwargs)
|
return DistilBertConfig(*args, **kwargs)
|
||||||
elif 'roberta' in model_type:
|
elif "roberta" in model_type:
|
||||||
return RobertaConfig(*args, **kwargs)
|
return RobertaConfig(*args, **kwargs)
|
||||||
elif 'bert' in model_type:
|
elif "bert" in model_type:
|
||||||
return BertConfig(*args, **kwargs)
|
return BertConfig(*args, **kwargs)
|
||||||
elif 'openai-gpt' in model_type:
|
elif "openai-gpt" in model_type:
|
||||||
return OpenAIGPTConfig(*args, **kwargs)
|
return OpenAIGPTConfig(*args, **kwargs)
|
||||||
elif 'gpt2' in model_type:
|
elif "gpt2" in model_type:
|
||||||
return GPT2Config(*args, **kwargs)
|
return GPT2Config(*args, **kwargs)
|
||||||
elif 'transfo-xl' in model_type:
|
elif "transfo-xl" in model_type:
|
||||||
return TransfoXLConfig(*args, **kwargs)
|
return TransfoXLConfig(*args, **kwargs)
|
||||||
elif 'xlnet' in model_type:
|
elif "xlnet" in model_type:
|
||||||
return XLNetConfig(*args, **kwargs)
|
return XLNetConfig(*args, **kwargs)
|
||||||
elif 'xlm' in model_type:
|
elif "xlm" in model_type:
|
||||||
return XLMConfig(*args, **kwargs)
|
return XLMConfig(*args, **kwargs)
|
||||||
elif 'ctrl' in model_type:
|
elif "ctrl" in model_type:
|
||||||
return CTRLConfig(*args, **kwargs)
|
return CTRLConfig(*args, **kwargs)
|
||||||
elif 'albert' in model_type:
|
elif "albert" in model_type:
|
||||||
return AlbertConfig(*args, **kwargs)
|
return AlbertConfig(*args, **kwargs)
|
||||||
elif 'camembert' in model_type:
|
elif "camembert" in model_type:
|
||||||
return CamembertConfig(*args, **kwargs)
|
return CamembertConfig(*args, **kwargs)
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError(
|
||||||
|
"Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
"'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type))
|
"'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type)
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||||
@@ -176,32 +183,36 @@ class AutoConfig(object):
|
|||||||
assert unused_kwargs == {'foo': False}
|
assert unused_kwargs == {'foo': False}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if 't5' in pretrained_model_name_or_path:
|
if "t5" in pretrained_model_name_or_path:
|
||||||
return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'distilbert' in pretrained_model_name_or_path:
|
elif "distilbert" in pretrained_model_name_or_path:
|
||||||
return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'albert' in pretrained_model_name_or_path:
|
elif "albert" in pretrained_model_name_or_path:
|
||||||
return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'camembert' in pretrained_model_name_or_path:
|
elif "camembert" in pretrained_model_name_or_path:
|
||||||
return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'xlm-roberta' in pretrained_model_name_or_path:
|
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||||
return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'roberta' in pretrained_model_name_or_path:
|
elif "roberta" in pretrained_model_name_or_path:
|
||||||
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'bert' in pretrained_model_name_or_path:
|
elif "bert" in pretrained_model_name_or_path:
|
||||||
return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'openai-gpt' in pretrained_model_name_or_path:
|
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||||
return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'gpt2' in pretrained_model_name_or_path:
|
elif "gpt2" in pretrained_model_name_or_path:
|
||||||
return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'transfo-xl' in pretrained_model_name_or_path:
|
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||||
return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'xlnet' in pretrained_model_name_or_path:
|
elif "xlnet" in pretrained_model_name_or_path:
|
||||||
return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'xlm' in pretrained_model_name_or_path:
|
elif "xlm" in pretrained_model_name_or_path:
|
||||||
return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'ctrl' in pretrained_model_name_or_path:
|
elif "ctrl" in pretrained_model_name_or_path:
|
||||||
return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError(
|
||||||
|
"Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
"'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(pretrained_model_name_or_path))
|
"'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(
|
||||||
|
pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
@@ -27,27 +27,27 @@ from .configuration_utils import PretrainedConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
|
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
|
||||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
|
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
|
||||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
|
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
|
||||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
|
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
|
||||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
|
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
|
||||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
|
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
|
||||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
|
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
|
||||||
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
|
"bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
|
||||||
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
|
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
|
||||||
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
|
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
|
||||||
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
|
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
|
||||||
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
|
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
|
||||||
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
|
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
|
||||||
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
|
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
|
||||||
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
|
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
|
||||||
'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
|
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
|
||||||
'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
|
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
|
||||||
'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
|
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
|
||||||
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
|
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
|
||||||
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
|
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
|
||||||
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
|
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -82,7 +82,8 @@ class BertConfig(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
vocab_size=30522,
|
vocab_size=30522,
|
||||||
hidden_size=768,
|
hidden_size=768,
|
||||||
num_hidden_layers=12,
|
num_hidden_layers=12,
|
||||||
@@ -95,7 +96,8 @@ class BertConfig(PretrainedConfig):
|
|||||||
type_vocab_size=2,
|
type_vocab_size=2,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
layer_norm_eps=1e-12,
|
layer_norm_eps=1e-12,
|
||||||
**kwargs):
|
**kwargs
|
||||||
|
):
|
||||||
super(BertConfig, self).__init__(**kwargs)
|
super(BertConfig, self).__init__(**kwargs)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
|
|||||||
@@ -15,8 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" CamemBERT configuration """
|
""" CamemBERT configuration """
|
||||||
|
|
||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
unicode_literals)
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -25,7 +24,7 @@ from .configuration_roberta import RobertaConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
|
"camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}
|
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}
|
||||||
|
|
||||||
|
|
||||||
class CTRLConfig(PretrainedConfig):
|
class CTRLConfig(PretrainedConfig):
|
||||||
"""Configuration class to store the configuration of a `CTRLModel`.
|
"""Configuration class to store the configuration of a `CTRLModel`.
|
||||||
|
|
||||||
@@ -48,6 +49,7 @@ class CTRLConfig(PretrainedConfig):
|
|||||||
initializer_range: The sttdev of the truncated_normal_initializer for
|
initializer_range: The sttdev of the truncated_normal_initializer for
|
||||||
initializing all weight matrices.
|
initializing all weight matrices.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -64,7 +66,7 @@ class CTRLConfig(PretrainedConfig):
|
|||||||
attn_pdrop=0.1,
|
attn_pdrop=0.1,
|
||||||
layer_norm_epsilon=1e-6,
|
layer_norm_epsilon=1e-6,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
summary_type='cls_index',
|
summary_type="cls_index",
|
||||||
summary_use_proj=True,
|
summary_use_proj=True,
|
||||||
summary_activation=None,
|
summary_activation=None,
|
||||||
summary_proj_to_labels=True,
|
summary_proj_to_labels=True,
|
||||||
|
|||||||
@@ -13,8 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" DistilBERT model configuration """
|
""" DistilBERT model configuration """
|
||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
unicode_literals)
|
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
@@ -26,17 +25,18 @@ from .configuration_utils import PretrainedConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
|
"distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
|
||||||
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json",
|
"distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json",
|
||||||
'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json",
|
"distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json",
|
||||||
'distilbert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json",
|
"distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class DistilBertConfig(PretrainedConfig):
|
class DistilBertConfig(PretrainedConfig):
|
||||||
pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
vocab_size=30522,
|
vocab_size=30522,
|
||||||
max_position_embeddings=512,
|
max_position_embeddings=512,
|
||||||
sinusoidal_pos_embds=False,
|
sinusoidal_pos_embds=False,
|
||||||
@@ -46,12 +46,13 @@ class DistilBertConfig(PretrainedConfig):
|
|||||||
hidden_dim=4 * 768,
|
hidden_dim=4 * 768,
|
||||||
dropout=0.1,
|
dropout=0.1,
|
||||||
attention_dropout=0.1,
|
attention_dropout=0.1,
|
||||||
activation='gelu',
|
activation="gelu",
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
tie_weights_=True,
|
tie_weights_=True,
|
||||||
qa_dropout=0.1,
|
qa_dropout=0.1,
|
||||||
seq_classif_dropout=0.2,
|
seq_classif_dropout=0.2,
|
||||||
**kwargs):
|
**kwargs
|
||||||
|
):
|
||||||
super(DistilBertConfig, self).__init__(**kwargs)
|
super(DistilBertConfig, self).__init__(**kwargs)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|||||||
@@ -26,11 +26,14 @@ from .configuration_utils import PretrainedConfig
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
||||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
|
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
|
||||||
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
|
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
|
||||||
"gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json",
|
"gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json",
|
||||||
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",}
|
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class GPT2Config(PretrainedConfig):
|
class GPT2Config(PretrainedConfig):
|
||||||
"""Configuration class to store the configuration of a `GPT2Model`.
|
"""Configuration class to store the configuration of a `GPT2Model`.
|
||||||
@@ -52,6 +55,7 @@ class GPT2Config(PretrainedConfig):
|
|||||||
initializer_range: The sttdev of the truncated_normal_initializer for
|
initializer_range: The sttdev of the truncated_normal_initializer for
|
||||||
initializing all weight matrices.
|
initializing all weight matrices.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -67,7 +71,7 @@ class GPT2Config(PretrainedConfig):
|
|||||||
attn_pdrop=0.1,
|
attn_pdrop=0.1,
|
||||||
layer_norm_epsilon=1e-5,
|
layer_norm_epsilon=1e-5,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
summary_type='cls_index',
|
summary_type="cls_index",
|
||||||
summary_use_proj=True,
|
summary_use_proj=True,
|
||||||
summary_activation=None,
|
summary_activation=None,
|
||||||
summary_proj_to_labels=True,
|
summary_proj_to_labels=True,
|
||||||
|
|||||||
@@ -15,8 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" MMBT configuration """
|
""" MMBT configuration """
|
||||||
|
|
||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
unicode_literals)
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -31,6 +30,7 @@ class MMBTConfig(object):
|
|||||||
num_labels: Size of final Linear layer for classification.
|
num_labels: Size of final Linear layer for classification.
|
||||||
modal_hidden_size: Embedding dimension of the non-text modality encoder.
|
modal_hidden_size: Embedding dimension of the non-text modality encoder.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, num_labels=None, modal_hidden_size=2048):
|
def __init__(self, config, num_labels=None, modal_hidden_size=2048):
|
||||||
self.__dict__ = config.__dict__
|
self.__dict__ = config.__dict__
|
||||||
self.modal_hidden_size = modal_hidden_size
|
self.modal_hidden_size = modal_hidden_size
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
|||||||
"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"
|
"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class OpenAIGPTConfig(PretrainedConfig):
|
class OpenAIGPTConfig(PretrainedConfig):
|
||||||
"""
|
"""
|
||||||
Configuration class to store the configuration of a `OpenAIGPTModel`.
|
Configuration class to store the configuration of a `OpenAIGPTModel`.
|
||||||
@@ -54,6 +55,7 @@ class OpenAIGPTConfig(PretrainedConfig):
|
|||||||
initializing all weight matrices.
|
initializing all weight matrices.
|
||||||
predict_special_tokens: should we predict special tokens (when the model has a LM head)
|
predict_special_tokens: should we predict special tokens (when the model has a LM head)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -71,7 +73,7 @@ class OpenAIGPTConfig(PretrainedConfig):
|
|||||||
layer_norm_epsilon=1e-5,
|
layer_norm_epsilon=1e-5,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
predict_special_tokens=True,
|
predict_special_tokens=True,
|
||||||
summary_type='cls_index',
|
summary_type="cls_index",
|
||||||
summary_use_proj=True,
|
summary_use_proj=True,
|
||||||
summary_activation=None,
|
summary_activation=None,
|
||||||
summary_proj_to_labels=True,
|
summary_proj_to_labels=True,
|
||||||
|
|||||||
@@ -15,8 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" RoBERTa configuration """
|
""" RoBERTa configuration """
|
||||||
|
|
||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
unicode_literals)
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -25,12 +24,12 @@ from .configuration_bert import BertConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
|
"roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
|
||||||
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
|
"roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
|
||||||
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
|
"roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
|
||||||
'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
|
"distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
|
||||||
'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
|
"roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
|
||||||
'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
|
"roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -27,11 +27,11 @@ from .configuration_utils import PretrainedConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
|
"t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
|
||||||
't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
|
"t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
|
||||||
't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
|
"t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
|
||||||
't5-3b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json",
|
"t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json",
|
||||||
't5-11b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json",
|
"t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -65,7 +65,8 @@ class T5Config(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
vocab_size=32128,
|
vocab_size=32128,
|
||||||
n_positions=512,
|
n_positions=512,
|
||||||
d_model=512,
|
d_model=512,
|
||||||
@@ -77,7 +78,8 @@ class T5Config(PretrainedConfig):
|
|||||||
dropout_rate=0.1,
|
dropout_rate=0.1,
|
||||||
layer_norm_epsilon=1e-6,
|
layer_norm_epsilon=1e-6,
|
||||||
initializer_factor=1.0,
|
initializer_factor=1.0,
|
||||||
**kwargs):
|
**kwargs
|
||||||
|
):
|
||||||
super(T5Config, self).__init__(**kwargs)
|
super(T5Config, self).__init__(**kwargs)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.n_positions = n_positions
|
self.n_positions = n_positions
|
||||||
|
|||||||
@@ -27,9 +27,10 @@ from .configuration_utils import PretrainedConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
|
"transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class TransfoXLConfig(PretrainedConfig):
|
class TransfoXLConfig(PretrainedConfig):
|
||||||
"""Configuration class to store the configuration of a `TransfoXLModel`.
|
"""Configuration class to store the configuration of a `TransfoXLModel`.
|
||||||
|
|
||||||
@@ -65,9 +66,11 @@ class TransfoXLConfig(PretrainedConfig):
|
|||||||
proj_init_std: parameters initialized by N(0, init_std)
|
proj_init_std: parameters initialized by N(0, init_std)
|
||||||
init_std: parameters initialized by N(0, init_std)
|
init_std: parameters initialized by N(0, init_std)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
vocab_size=267735,
|
vocab_size=267735,
|
||||||
cutoffs=[20000, 40000, 200000],
|
cutoffs=[20000, 40000, 200000],
|
||||||
d_model=1024,
|
d_model=1024,
|
||||||
@@ -96,7 +99,8 @@ class TransfoXLConfig(PretrainedConfig):
|
|||||||
proj_init_std=0.01,
|
proj_init_std=0.01,
|
||||||
init_std=0.02,
|
init_std=0.02,
|
||||||
layer_norm_epsilon=1e-5,
|
layer_norm_epsilon=1e-5,
|
||||||
**kwargs):
|
**kwargs
|
||||||
|
):
|
||||||
"""Constructs TransfoXLConfig.
|
"""Constructs TransfoXLConfig.
|
||||||
"""
|
"""
|
||||||
super(TransfoXLConfig, self).__init__(**kwargs)
|
super(TransfoXLConfig, self).__init__(**kwargs)
|
||||||
|
|||||||
@@ -15,8 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Configuration base class and utilities."""
|
""" Configuration base class and utilities."""
|
||||||
|
|
||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
unicode_literals)
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
@@ -28,6 +27,7 @@ from .file_utils import CONFIG_NAME, cached_path, is_remote_url, hf_bucket_url
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PretrainedConfig(object):
|
class PretrainedConfig(object):
|
||||||
r""" Base class for all configuration classes.
|
r""" Base class for all configuration classes.
|
||||||
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
|
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
|
||||||
@@ -50,36 +50,36 @@ class PretrainedConfig(object):
|
|||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
# Attributes with defaults
|
# Attributes with defaults
|
||||||
self.output_attentions = kwargs.pop('output_attentions', False)
|
self.output_attentions = kwargs.pop("output_attentions", False)
|
||||||
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
|
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
||||||
self.output_past = kwargs.pop('output_past', True) # Not used by all models
|
self.output_past = kwargs.pop("output_past", True) # Not used by all models
|
||||||
self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models
|
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
||||||
self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
|
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
||||||
self.pruned_heads = kwargs.pop('pruned_heads', {})
|
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
||||||
|
|
||||||
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
|
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
|
||||||
self.is_decoder = kwargs.pop('is_decoder', False)
|
self.is_decoder = kwargs.pop("is_decoder", False)
|
||||||
|
|
||||||
# Parameters for sequence generation
|
# Parameters for sequence generation
|
||||||
self.max_length = kwargs.pop('max_length', 20)
|
self.max_length = kwargs.pop("max_length", 20)
|
||||||
self.do_sample = kwargs.pop('do_sample', False)
|
self.do_sample = kwargs.pop("do_sample", False)
|
||||||
self.num_beams = kwargs.pop('num_beams', 1)
|
self.num_beams = kwargs.pop("num_beams", 1)
|
||||||
self.temperature = kwargs.pop('temperature', 1.0)
|
self.temperature = kwargs.pop("temperature", 1.0)
|
||||||
self.top_k = kwargs.pop('top_k', 50)
|
self.top_k = kwargs.pop("top_k", 50)
|
||||||
self.top_p = kwargs.pop('top_p', 1.0)
|
self.top_p = kwargs.pop("top_p", 1.0)
|
||||||
self.repetition_penalty = kwargs.pop('repetition_penalty', 1.0)
|
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
||||||
self.bos_token_id = kwargs.pop('bos_token_id', 0)
|
self.bos_token_id = kwargs.pop("bos_token_id", 0)
|
||||||
self.pad_token_id = kwargs.pop('pad_token_id', 0)
|
self.pad_token_id = kwargs.pop("pad_token_id", 0)
|
||||||
self.eos_token_ids = kwargs.pop('eos_token_ids', 0)
|
self.eos_token_ids = kwargs.pop("eos_token_ids", 0)
|
||||||
self.length_penalty = kwargs.pop('length_penalty', 1.)
|
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
||||||
self.num_return_sequences = kwargs.pop('num_return_sequences', 1)
|
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||||
|
|
||||||
# Fine-tuning task arguments
|
# Fine-tuning task arguments
|
||||||
self.finetuning_task = kwargs.pop('finetuning_task', None)
|
self.finetuning_task = kwargs.pop("finetuning_task", None)
|
||||||
self.num_labels = kwargs.pop('num_labels', 2)
|
self.num_labels = kwargs.pop("num_labels", 2)
|
||||||
self.id2label = kwargs.pop('id2label', {i: 'LABEL_{}'.format(i) for i in range(self.num_labels)})
|
self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)})
|
||||||
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
|
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
|
||||||
self.label2id = kwargs.pop('label2id', dict(zip(self.id2label.values(), self.id2label.keys())))
|
self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
|
||||||
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
|
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
|
||||||
|
|
||||||
# Additional attributes without default values
|
# Additional attributes without default values
|
||||||
@@ -94,7 +94,9 @@ class PretrainedConfig(object):
|
|||||||
""" Save a configuration object to the directory `save_directory`, so that it
|
""" Save a configuration object to the directory `save_directory`, so that it
|
||||||
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
|
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
|
||||||
"""
|
"""
|
||||||
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
|
assert os.path.isdir(
|
||||||
|
save_directory
|
||||||
|
), "Saving path should be a directory where the model and configuration can be saved"
|
||||||
|
|
||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
||||||
@@ -153,11 +155,11 @@ class PretrainedConfig(object):
|
|||||||
assert unused_kwargs == {'foo': False}
|
assert unused_kwargs == {'foo': False}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
cache_dir = kwargs.pop('cache_dir', None)
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
force_download = kwargs.pop('force_download', False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
resume_download = kwargs.pop('resume_download', False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop('proxies', None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||||
|
|
||||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||||
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
|
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
|
||||||
@@ -170,37 +172,48 @@ class PretrainedConfig(object):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
|
resolved_config_file = cached_path(
|
||||||
proxies=proxies, resume_download=resume_download)
|
config_file,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
)
|
||||||
# Load config
|
# Load config
|
||||||
config = cls.from_json_file(resolved_config_file)
|
config = cls.from_json_file(resolved_config_file)
|
||||||
|
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||||
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
|
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
|
||||||
config_file)
|
config_file
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
msg = "Model name '{}' was not found in model name list ({}). " \
|
msg = (
|
||||||
"We assumed '{}' was a path or url to a configuration file named {} or " \
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
|
"We assumed '{}' was a path or url to a configuration file named {} or "
|
||||||
"a directory containing such a file but couldn't find any such file at this path or url.".format(
|
"a directory containing such a file but couldn't find any such file at this path or url.".format(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
', '.join(cls.pretrained_config_archive_map.keys()),
|
", ".join(cls.pretrained_config_archive_map.keys()),
|
||||||
config_file, CONFIG_NAME)
|
config_file,
|
||||||
|
CONFIG_NAME,
|
||||||
|
)
|
||||||
|
)
|
||||||
raise EnvironmentError(msg)
|
raise EnvironmentError(msg)
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
msg = "Couldn't reach server at '{}' to download configuration file or " \
|
msg = (
|
||||||
"configuration file is not a valid JSON file. " \
|
"Couldn't reach server at '{}' to download configuration file or "
|
||||||
|
"configuration file is not a valid JSON file. "
|
||||||
"Please check network or file content here: {}.".format(config_file, resolved_config_file)
|
"Please check network or file content here: {}.".format(config_file, resolved_config_file)
|
||||||
|
)
|
||||||
raise EnvironmentError(msg)
|
raise EnvironmentError(msg)
|
||||||
|
|
||||||
if resolved_config_file == config_file:
|
if resolved_config_file == config_file:
|
||||||
logger.info("loading configuration file {}".format(config_file))
|
logger.info("loading configuration file {}".format(config_file))
|
||||||
else:
|
else:
|
||||||
logger.info("loading configuration file {} from cache at {}".format(
|
logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
|
||||||
config_file, resolved_config_file))
|
|
||||||
|
|
||||||
if hasattr(config, 'pruned_heads'):
|
if hasattr(config, "pruned_heads"):
|
||||||
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
|
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
|
||||||
|
|
||||||
# Update config with kwargs if needed
|
# Update config with kwargs if needed
|
||||||
@@ -226,7 +239,7 @@ class PretrainedConfig(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_json_file(cls, json_file):
|
def from_json_file(cls, json_file):
|
||||||
"""Constructs a `Config` from a json file of parameters."""
|
"""Constructs a `Config` from a json file of parameters."""
|
||||||
with open(json_file, "r", encoding='utf-8') as reader:
|
with open(json_file, "r", encoding="utf-8") as reader:
|
||||||
text = reader.read()
|
text = reader.read()
|
||||||
dict_obj = json.loads(text)
|
dict_obj = json.loads(text)
|
||||||
return cls(**dict_obj)
|
return cls(**dict_obj)
|
||||||
@@ -248,5 +261,5 @@ class PretrainedConfig(object):
|
|||||||
|
|
||||||
def to_json_file(self, json_file_path):
|
def to_json_file(self, json_file_path):
|
||||||
""" Save this instance to a json file."""
|
""" Save this instance to a json file."""
|
||||||
with open(json_file_path, "w", encoding='utf-8') as writer:
|
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||||
writer.write(self.to_json_string())
|
writer.write(self.to_json_string())
|
||||||
|
|||||||
@@ -25,16 +25,16 @@ from .configuration_utils import PretrainedConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
|
"xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
|
||||||
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
|
"xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
|
||||||
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
|
"xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
|
||||||
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
|
"xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
|
||||||
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
|
"xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
|
||||||
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
|
"xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
|
||||||
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
|
"xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
|
||||||
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
|
"xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
|
||||||
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
|
"xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
|
||||||
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
|
"xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -78,9 +78,11 @@ class XLMConfig(PretrainedConfig):
|
|||||||
-1 means no clamping.
|
-1 means no clamping.
|
||||||
same_length: bool, whether to use the same attention length for each token.
|
same_length: bool, whether to use the same attention length for each token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
vocab_size=30145,
|
vocab_size=30145,
|
||||||
emb_dim=2048,
|
emb_dim=2048,
|
||||||
n_layers=12,
|
n_layers=12,
|
||||||
@@ -103,7 +105,7 @@ class XLMConfig(PretrainedConfig):
|
|||||||
unk_index=3,
|
unk_index=3,
|
||||||
mask_index=5,
|
mask_index=5,
|
||||||
is_encoder=True,
|
is_encoder=True,
|
||||||
summary_type='first',
|
summary_type="first",
|
||||||
summary_use_proj=True,
|
summary_use_proj=True,
|
||||||
summary_activation=None,
|
summary_activation=None,
|
||||||
summary_proj_to_labels=True,
|
summary_proj_to_labels=True,
|
||||||
@@ -112,7 +114,8 @@ class XLMConfig(PretrainedConfig):
|
|||||||
end_n_top=5,
|
end_n_top=5,
|
||||||
mask_token_id=0,
|
mask_token_id=0,
|
||||||
lang_id=0,
|
lang_id=0,
|
||||||
**kwargs):
|
**kwargs
|
||||||
|
):
|
||||||
"""Constructs XLMConfig.
|
"""Constructs XLMConfig.
|
||||||
"""
|
"""
|
||||||
super(XLMConfig, self).__init__(**kwargs)
|
super(XLMConfig, self).__init__(**kwargs)
|
||||||
|
|||||||
@@ -15,8 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" XLM-RoBERTa configuration """
|
""" XLM-RoBERTa configuration """
|
||||||
|
|
||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
unicode_literals)
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -25,12 +24,12 @@ from .configuration_roberta import RobertaConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'xlm-roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
|
"xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
|
||||||
'xlm-roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json",
|
"xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json",
|
||||||
'xlm-roberta-large-finetuned-conll02-dutch': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
|
"xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
|
||||||
'xlm-roberta-large-finetuned-conll02-spanish': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
|
"xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
|
||||||
'xlm-roberta-large-finetuned-conll03-english': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
|
"xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
|
||||||
'xlm-roberta-large-finetuned-conll03-german': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
|
"xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ from .configuration_utils import PretrainedConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
|
"xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
|
||||||
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
|
"xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -69,9 +69,11 @@ class XLNetConfig(PretrainedConfig):
|
|||||||
same_length: bool, whether to use the same attention length for each token.
|
same_length: bool, whether to use the same attention length for each token.
|
||||||
finetuning_task: name of the glue task on which the model was fine-tuned if any
|
finetuning_task: name of the glue task on which the model was fine-tuned if any
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
vocab_size=32000,
|
vocab_size=32000,
|
||||||
d_model=1024,
|
d_model=1024,
|
||||||
n_layer=24,
|
n_layer=24,
|
||||||
@@ -88,13 +90,14 @@ class XLNetConfig(PretrainedConfig):
|
|||||||
bi_data=False,
|
bi_data=False,
|
||||||
clamp_len=-1,
|
clamp_len=-1,
|
||||||
same_length=False,
|
same_length=False,
|
||||||
summary_type='last',
|
summary_type="last",
|
||||||
summary_use_proj=True,
|
summary_use_proj=True,
|
||||||
summary_activation='tanh',
|
summary_activation="tanh",
|
||||||
summary_last_dropout=0.1,
|
summary_last_dropout=0.1,
|
||||||
start_n_top=5,
|
start_n_top=5,
|
||||||
end_n_top=5,
|
end_n_top=5,
|
||||||
**kwargs):
|
**kwargs
|
||||||
|
):
|
||||||
"""Constructs XLNetConfig.
|
"""Constructs XLNetConfig.
|
||||||
"""
|
"""
|
||||||
super(XLNetConfig, self).__init__(**kwargs)
|
super(XLNetConfig, self).__init__(**kwargs)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import torch
|
|||||||
from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert
|
from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
@@ -44,24 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--tf_checkpoint_path",
|
parser.add_argument(
|
||||||
default = None,
|
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||||
type = str,
|
)
|
||||||
required = True,
|
parser.add_argument(
|
||||||
help = "Path to the TensorFlow checkpoint path.")
|
"--albert_config_file",
|
||||||
parser.add_argument("--albert_config_file",
|
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The config json file corresponding to the pre-trained ALBERT model. \n"
|
help="The config json file corresponding to the pre-trained ALBERT model. \n"
|
||||||
"This specifies the model architecture.")
|
"This specifies the model architecture.",
|
||||||
parser.add_argument("--pytorch_dump_path",
|
)
|
||||||
default = None,
|
parser.add_argument(
|
||||||
type = str,
|
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||||
required = True,
|
)
|
||||||
help = "Path to the output PyTorch model.")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
|
||||||
args.albert_config_file,
|
|
||||||
args.pytorch_dump_path)
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,8 +24,10 @@ import torch
|
|||||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||||
# Initialise PyTorch model
|
# Initialise PyTorch model
|
||||||
config = BertConfig.from_json_file(bert_config_file)
|
config = BertConfig.from_json_file(bert_config_file)
|
||||||
@@ -43,23 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--tf_checkpoint_path",
|
parser.add_argument(
|
||||||
default = None,
|
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||||
type = str,
|
)
|
||||||
required = True,
|
parser.add_argument(
|
||||||
help = "Path to the TensorFlow checkpoint path.")
|
"--bert_config_file",
|
||||||
parser.add_argument("--bert_config_file",
|
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The config json file corresponding to the pre-trained BERT model. \n"
|
help="The config json file corresponding to the pre-trained BERT model. \n"
|
||||||
"This specifies the model architecture.")
|
"This specifies the model architecture.",
|
||||||
parser.add_argument("--pytorch_dump_path",
|
)
|
||||||
default = None,
|
parser.add_argument(
|
||||||
type = str,
|
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||||
required = True,
|
)
|
||||||
help = "Path to the output PyTorch model.")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
||||||
args.bert_config_file,
|
|
||||||
args.pytorch_dump_path)
|
|
||||||
|
|||||||
@@ -41,22 +41,17 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
|
|||||||
N BertForQuestionAnswering
|
N BertForQuestionAnswering
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tensors_to_transpose = (
|
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
|
||||||
"dense.weight",
|
|
||||||
"attention.self.query",
|
|
||||||
"attention.self.key",
|
|
||||||
"attention.self.value"
|
|
||||||
)
|
|
||||||
|
|
||||||
var_map = (
|
var_map = (
|
||||||
('layer.', 'layer_'),
|
("layer.", "layer_"),
|
||||||
('word_embeddings.weight', 'word_embeddings'),
|
("word_embeddings.weight", "word_embeddings"),
|
||||||
('position_embeddings.weight', 'position_embeddings'),
|
("position_embeddings.weight", "position_embeddings"),
|
||||||
('token_type_embeddings.weight', 'token_type_embeddings'),
|
("token_type_embeddings.weight", "token_type_embeddings"),
|
||||||
('.', '/'),
|
(".", "/"),
|
||||||
('LayerNorm/weight', 'LayerNorm/gamma'),
|
("LayerNorm/weight", "LayerNorm/gamma"),
|
||||||
('LayerNorm/bias', 'LayerNorm/beta'),
|
("LayerNorm/bias", "LayerNorm/beta"),
|
||||||
('weight', 'kernel')
|
("weight", "kernel"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not os.path.isdir(ckpt_dir):
|
if not os.path.isdir(ckpt_dir):
|
||||||
@@ -67,7 +62,7 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
|
|||||||
def to_tf_var_name(name: str):
|
def to_tf_var_name(name: str):
|
||||||
for patt, repl in iter(var_map):
|
for patt, repl in iter(var_map):
|
||||||
name = name.replace(patt, repl)
|
name = name.replace(patt, repl)
|
||||||
return 'bert/{}'.format(name)
|
return "bert/{}".format(name)
|
||||||
|
|
||||||
def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
|
def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
|
||||||
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
|
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
|
||||||
@@ -94,36 +89,21 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
|
|||||||
|
|
||||||
def main(raw_args=None):
|
def main(raw_args=None):
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--model_name",
|
parser.add_argument("--model_name", type=str, required=True, help="model name e.g. bert-base-uncased")
|
||||||
type=str,
|
parser.add_argument(
|
||||||
required=True,
|
"--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
|
||||||
help="model name e.g. bert-base-uncased")
|
)
|
||||||
parser.add_argument("--cache_dir",
|
parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
|
||||||
type=str,
|
parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
|
||||||
default=None,
|
|
||||||
required=False,
|
|
||||||
help="Directory containing pytorch model")
|
|
||||||
parser.add_argument("--pytorch_model_path",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="/path/to/<pytorch-model-name>.bin")
|
|
||||||
parser.add_argument("--tf_cache_dir",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Directory in which to save tensorflow model")
|
|
||||||
args = parser.parse_args(raw_args)
|
args = parser.parse_args(raw_args)
|
||||||
|
|
||||||
model = BertModel.from_pretrained(
|
model = BertModel.from_pretrained(
|
||||||
pretrained_model_name_or_path=args.model_name,
|
pretrained_model_name_or_path=args.model_name,
|
||||||
state_dict=torch.load(args.pytorch_model_path),
|
state_dict=torch.load(args.pytorch_model_path),
|
||||||
cache_dir=args.cache_dir
|
cache_dir=args.cache_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
convert_pytorch_checkpoint_to_tf(
|
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
|
||||||
model=model,
|
|
||||||
ckpt_dir=args.tf_cache_dir,
|
|
||||||
model_name=args.model_name
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -21,12 +21,10 @@ from io import open
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import (CONFIG_NAME, WEIGHTS_NAME,
|
from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
||||||
GPT2Config,
|
|
||||||
GPT2Model,
|
|
||||||
load_tf_weights_in_gpt2)
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
@@ -42,8 +40,8 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
|
|||||||
load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
|
load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
|
||||||
|
|
||||||
# Save pytorch-model
|
# Save pytorch-model
|
||||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
|
||||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||||
@@ -54,22 +52,18 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--gpt2_checkpoint_path",
|
parser.add_argument(
|
||||||
default = None,
|
"--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||||
type = str,
|
)
|
||||||
required = True,
|
parser.add_argument(
|
||||||
help = "Path to the TensorFlow checkpoint path.")
|
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||||
parser.add_argument("--pytorch_dump_folder_path",
|
)
|
||||||
default = None,
|
parser.add_argument(
|
||||||
type = str,
|
"--gpt2_config_file",
|
||||||
required = True,
|
|
||||||
help = "Path to the output PyTorch model.")
|
|
||||||
parser.add_argument("--gpt2_config_file",
|
|
||||||
default="",
|
default="",
|
||||||
type=str,
|
type=str,
|
||||||
help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
|
help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
|
||||||
"This specifies the model architecture.")
|
"This specifies the model architecture.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
|
convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path)
|
||||||
args.gpt2_config_file,
|
|
||||||
args.pytorch_dump_folder_path)
|
|
||||||
|
|||||||
@@ -21,12 +21,10 @@ from io import open
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import (CONFIG_NAME, WEIGHTS_NAME,
|
from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
||||||
OpenAIGPTConfig,
|
|
||||||
OpenAIGPTModel,
|
|
||||||
load_tf_weights_in_openai_gpt)
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
@@ -42,8 +40,8 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
|||||||
load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)
|
load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)
|
||||||
|
|
||||||
# Save pytorch-model
|
# Save pytorch-model
|
||||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
|
||||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||||
@@ -54,22 +52,24 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--openai_checkpoint_folder_path",
|
parser.add_argument(
|
||||||
|
"--openai_checkpoint_folder_path",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help = "Path to the TensorFlow checkpoint path.")
|
help="Path to the TensorFlow checkpoint path.",
|
||||||
parser.add_argument("--pytorch_dump_folder_path",
|
)
|
||||||
default = None,
|
parser.add_argument(
|
||||||
type = str,
|
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||||
required = True,
|
)
|
||||||
help = "Path to the output PyTorch model.")
|
parser.add_argument(
|
||||||
parser.add_argument("--openai_config_file",
|
"--openai_config_file",
|
||||||
default="",
|
default="",
|
||||||
type=str,
|
type=str,
|
||||||
help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
|
help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
|
||||||
"This specifies the model architecture.")
|
"This specifies the model architecture.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
|
convert_openai_checkpoint_to_pytorch(
|
||||||
args.openai_config_file,
|
args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path
|
||||||
args.pytorch_dump_folder_path)
|
)
|
||||||
|
|||||||
@@ -24,82 +24,270 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from transformers import is_torch_available, cached_path
|
from transformers import is_torch_available, cached_path
|
||||||
|
|
||||||
from transformers import (load_pytorch_checkpoint_in_tf2_model,
|
from transformers import (
|
||||||
BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
load_pytorch_checkpoint_in_tf2_model,
|
||||||
GPT2Config, TFGPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
BertConfig,
|
||||||
XLNetConfig, TFXLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
TFBertForPreTraining,
|
||||||
XLMConfig, TFXLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
TFBertForQuestionAnswering,
|
||||||
TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
TFBertForSequenceClassification,
|
||||||
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
GPT2Config,
|
||||||
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, TFDistilBertForSequenceClassification, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
TFGPT2LMHeadModel,
|
||||||
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
AlbertConfig, TFAlbertForMaskedLM, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLNetConfig,
|
||||||
T5Config, TFT5WithLMHeadModel, T5_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
TFXLNetLMHeadModel,
|
||||||
|
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
XLMConfig,
|
||||||
|
TFXLMWithLMHeadModel,
|
||||||
|
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
TransfoXLConfig,
|
||||||
|
TFTransfoXLLMHeadModel,
|
||||||
|
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
OpenAIGPTConfig,
|
||||||
|
TFOpenAIGPTLMHeadModel,
|
||||||
|
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
RobertaConfig,
|
||||||
|
TFRobertaForMaskedLM,
|
||||||
|
TFRobertaForSequenceClassification,
|
||||||
|
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
DistilBertConfig,
|
||||||
|
TFDistilBertForMaskedLM,
|
||||||
|
TFDistilBertForQuestionAnswering,
|
||||||
|
TFDistilBertForSequenceClassification,
|
||||||
|
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
CTRLConfig,
|
||||||
|
TFCTRLLMHeadModel,
|
||||||
|
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
AlbertConfig,
|
||||||
|
TFAlbertForMaskedLM,
|
||||||
|
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
T5Config,
|
||||||
|
TFT5WithLMHeadModel,
|
||||||
|
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers import (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
from transformers import (
|
||||||
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BertForPreTraining,
|
||||||
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BertForQuestionAnswering,
|
||||||
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BertForSequenceClassification,
|
||||||
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
GPT2LMHeadModel,
|
||||||
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
XLNetLMHeadModel,
|
||||||
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
XLMWithLMHeadModel,
|
||||||
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP)
|
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
TransfoXLLMHeadModel,
|
||||||
|
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
OpenAIGPTLMHeadModel,
|
||||||
|
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
RobertaForMaskedLM,
|
||||||
|
RobertaForSequenceClassification,
|
||||||
|
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
DistilBertForMaskedLM,
|
||||||
|
DistilBertForQuestionAnswering,
|
||||||
|
DistilBertForSequenceClassification,
|
||||||
|
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
CTRLLMHeadModel,
|
||||||
|
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
AlbertForMaskedLM,
|
||||||
|
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
T5WithLMHeadModel,
|
||||||
|
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
(
|
||||||
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BertForPreTraining,
|
||||||
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BertForQuestionAnswering,
|
||||||
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BertForSequenceClassification,
|
||||||
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
GPT2LMHeadModel,
|
||||||
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
DistilBertForMaskedLM, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
XLNetLMHeadModel,
|
||||||
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
XLMWithLMHeadModel,
|
||||||
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) = (
|
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
None, None, None, None,
|
TransfoXLLMHeadModel,
|
||||||
None, None,
|
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
None, None,
|
OpenAIGPTLMHeadModel,
|
||||||
None, None,
|
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
None, None,
|
RobertaForMaskedLM,
|
||||||
None, None,
|
RobertaForSequenceClassification,
|
||||||
None, None, None,
|
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
None, None, None, None,
|
DistilBertForMaskedLM,
|
||||||
None, None,
|
DistilBertForSequenceClassification,
|
||||||
None, None,
|
DistilBertForQuestionAnswering,
|
||||||
None, None)
|
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
CTRLLMHeadModel,
|
||||||
|
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
AlbertForMaskedLM,
|
||||||
|
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
T5WithLMHeadModel,
|
||||||
|
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
) = (
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'bert': (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
"bert": (
|
||||||
'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
BertConfig,
|
||||||
'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
TFBertForPreTraining,
|
||||||
'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
BertForPreTraining,
|
||||||
'gpt2': (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
'xlm': (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
),
|
||||||
'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
"bert-large-uncased-whole-word-masking-finetuned-squad": (
|
||||||
'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
BertConfig,
|
||||||
'roberta': (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
TFBertForQuestionAnswering,
|
||||||
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
BertForQuestionAnswering,
|
||||||
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
),
|
||||||
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
"bert-large-cased-whole-word-masking-finetuned-squad": (
|
||||||
'albert': (AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
BertConfig,
|
||||||
't5': (T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
TFBertForQuestionAnswering,
|
||||||
|
BertForQuestionAnswering,
|
||||||
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"bert-base-cased-finetuned-mrpc": (
|
||||||
|
BertConfig,
|
||||||
|
TFBertForSequenceClassification,
|
||||||
|
BertForSequenceClassification,
|
||||||
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"gpt2": (
|
||||||
|
GPT2Config,
|
||||||
|
TFGPT2LMHeadModel,
|
||||||
|
GPT2LMHeadModel,
|
||||||
|
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"xlnet": (
|
||||||
|
XLNetConfig,
|
||||||
|
TFXLNetLMHeadModel,
|
||||||
|
XLNetLMHeadModel,
|
||||||
|
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"xlm": (
|
||||||
|
XLMConfig,
|
||||||
|
TFXLMWithLMHeadModel,
|
||||||
|
XLMWithLMHeadModel,
|
||||||
|
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"transfo-xl": (
|
||||||
|
TransfoXLConfig,
|
||||||
|
TFTransfoXLLMHeadModel,
|
||||||
|
TransfoXLLMHeadModel,
|
||||||
|
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"openai-gpt": (
|
||||||
|
OpenAIGPTConfig,
|
||||||
|
TFOpenAIGPTLMHeadModel,
|
||||||
|
OpenAIGPTLMHeadModel,
|
||||||
|
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"roberta": (
|
||||||
|
RobertaConfig,
|
||||||
|
TFRobertaForMaskedLM,
|
||||||
|
RobertaForMaskedLM,
|
||||||
|
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"roberta-large-mnli": (
|
||||||
|
RobertaConfig,
|
||||||
|
TFRobertaForSequenceClassification,
|
||||||
|
RobertaForSequenceClassification,
|
||||||
|
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"distilbert": (
|
||||||
|
DistilBertConfig,
|
||||||
|
TFDistilBertForMaskedLM,
|
||||||
|
DistilBertForMaskedLM,
|
||||||
|
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"distilbert-base-uncased-distilled-squad": (
|
||||||
|
DistilBertConfig,
|
||||||
|
TFDistilBertForQuestionAnswering,
|
||||||
|
DistilBertForQuestionAnswering,
|
||||||
|
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"distilbert-base-uncased-distilled-squad": (
|
||||||
|
DistilBertConfig,
|
||||||
|
TFDistilBertForQuestionAnswering,
|
||||||
|
DistilBertForQuestionAnswering,
|
||||||
|
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"ctrl": (
|
||||||
|
CTRLConfig,
|
||||||
|
TFCTRLLMHeadModel,
|
||||||
|
CTRLLMHeadModel,
|
||||||
|
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"albert": (
|
||||||
|
AlbertConfig,
|
||||||
|
TFAlbertForMaskedLM,
|
||||||
|
AlbertForMaskedLM,
|
||||||
|
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
|
"t5": (
|
||||||
|
T5Config,
|
||||||
|
TFT5WithLMHeadModel,
|
||||||
|
T5WithLMHeadModel,
|
||||||
|
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
|
|
||||||
|
def convert_pt_checkpoint_to_tf(
|
||||||
|
model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
|
||||||
|
):
|
||||||
if model_type not in MODEL_CLASSES:
|
if model_type not in MODEL_CLASSES:
|
||||||
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
|
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
|
||||||
|
|
||||||
@@ -116,17 +304,19 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
|||||||
|
|
||||||
# Load weights from tf checkpoint
|
# Load weights from tf checkpoint
|
||||||
if pytorch_checkpoint_path in aws_model_maps:
|
if pytorch_checkpoint_path in aws_model_maps:
|
||||||
pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)
|
pytorch_checkpoint_path = cached_path(
|
||||||
|
aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models
|
||||||
|
)
|
||||||
# Load PyTorch checkpoint in tf2 model:
|
# Load PyTorch checkpoint in tf2 model:
|
||||||
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
|
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
|
||||||
|
|
||||||
if compare_with_pt_model:
|
if compare_with_pt_model:
|
||||||
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
|
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
|
||||||
|
|
||||||
state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
|
state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
|
||||||
pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
|
pt_model = pt_model_class.from_pretrained(
|
||||||
config=config,
|
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
|
||||||
state_dict=state_dict)
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pto = pt_model(**pt_model.dummy_inputs)
|
pto = pt_model(**pt_model.dummy_inputs)
|
||||||
@@ -139,11 +329,19 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
|||||||
|
|
||||||
# Save pytorch-model
|
# Save pytorch-model
|
||||||
print("Save TensorFlow model to {}".format(tf_dump_path))
|
print("Save TensorFlow model to {}".format(tf_dump_path))
|
||||||
tf_model.save_weights(tf_dump_path, save_format='h5')
|
tf_model.save_weights(tf_dump_path, save_format="h5")
|
||||||
|
|
||||||
|
|
||||||
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
|
def convert_all_pt_checkpoints_to_tf(
|
||||||
compare_with_pt_model=False, use_cached_models=False, remove_cached_files=False, only_convert_finetuned_models=False):
|
args_model_type,
|
||||||
|
tf_dump_path,
|
||||||
|
model_shortcut_names_or_path=None,
|
||||||
|
config_shortcut_names_or_path=None,
|
||||||
|
compare_with_pt_model=False,
|
||||||
|
use_cached_models=False,
|
||||||
|
remove_cached_files=False,
|
||||||
|
only_convert_finetuned_models=False,
|
||||||
|
):
|
||||||
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
|
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
|
||||||
|
|
||||||
if args_model_type is None:
|
if args_model_type is None:
|
||||||
@@ -156,7 +354,9 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
|||||||
print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
|
print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
|
||||||
print("=" * 100)
|
print("=" * 100)
|
||||||
if model_type not in MODEL_CLASSES:
|
if model_type not in MODEL_CLASSES:
|
||||||
raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))
|
raise ValueError(
|
||||||
|
"Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys()))
|
||||||
|
)
|
||||||
|
|
||||||
config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
||||||
|
|
||||||
@@ -166,9 +366,10 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
|||||||
config_shortcut_names_or_path = model_shortcut_names_or_path
|
config_shortcut_names_or_path = model_shortcut_names_or_path
|
||||||
|
|
||||||
for i, (model_shortcut_name, config_shortcut_name) in enumerate(
|
for i, (model_shortcut_name, config_shortcut_name) in enumerate(
|
||||||
zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1):
|
zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
|
||||||
|
):
|
||||||
print("-" * 100)
|
print("-" * 100)
|
||||||
if '-squad' in model_shortcut_name or '-mrpc' in model_shortcut_name or '-mnli' in model_shortcut_name:
|
if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
|
||||||
if not only_convert_finetuned_models:
|
if not only_convert_finetuned_models:
|
||||||
print(" Skipping finetuned checkpoint {}".format(model_shortcut_name))
|
print(" Skipping finetuned checkpoint {}".format(model_shortcut_name))
|
||||||
continue
|
continue
|
||||||
@@ -176,7 +377,11 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
|||||||
elif only_convert_finetuned_models:
|
elif only_convert_finetuned_models:
|
||||||
print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name))
|
print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name))
|
||||||
continue
|
continue
|
||||||
print(" Converting checkpoint {}/{}: {} - model_type {}".format(i, len(aws_config_map), model_shortcut_name, model_type))
|
print(
|
||||||
|
" Converting checkpoint {}/{}: {} - model_type {}".format(
|
||||||
|
i, len(aws_config_map), model_shortcut_name, model_type
|
||||||
|
)
|
||||||
|
)
|
||||||
print("-" * 100)
|
print("-" * 100)
|
||||||
|
|
||||||
if config_shortcut_name in aws_config_map:
|
if config_shortcut_name in aws_config_map:
|
||||||
@@ -190,13 +395,15 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
|||||||
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
|
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
|
||||||
|
|
||||||
if os.path.isfile(model_shortcut_name):
|
if os.path.isfile(model_shortcut_name):
|
||||||
model_shortcut_name = 'converted_model'
|
model_shortcut_name = "converted_model"
|
||||||
|
|
||||||
convert_pt_checkpoint_to_tf(model_type=model_type,
|
convert_pt_checkpoint_to_tf(
|
||||||
|
model_type=model_type,
|
||||||
pytorch_checkpoint_path=model_file,
|
pytorch_checkpoint_path=model_file,
|
||||||
config_file=config_file,
|
config_file=config_file,
|
||||||
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
|
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
|
||||||
compare_with_pt_model=compare_with_pt_model)
|
compare_with_pt_model=compare_with_pt_model,
|
||||||
|
)
|
||||||
if remove_cached_files:
|
if remove_cached_files:
|
||||||
os.remove(config_file)
|
os.remove(config_file)
|
||||||
os.remove(model_file)
|
os.remove(model_file)
|
||||||
@@ -205,39 +412,47 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--tf_dump_path",
|
parser.add_argument(
|
||||||
|
"--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_type",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required = True,
|
help="Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(
|
||||||
help = "Path to the output Tensorflow dump file.")
|
list(MODEL_CLASSES.keys())
|
||||||
parser.add_argument("--model_type",
|
),
|
||||||
default = None,
|
)
|
||||||
type = str,
|
parser.add_argument(
|
||||||
help = "Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(list(MODEL_CLASSES.keys())))
|
"--pytorch_checkpoint_path",
|
||||||
parser.add_argument("--pytorch_checkpoint_path",
|
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
|
help="Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
|
||||||
"If not given, will download and convert all the checkpoints from AWS.")
|
"If not given, will download and convert all the checkpoints from AWS.",
|
||||||
parser.add_argument("--config_file",
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config_file",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="The config json file corresponding to the pre-trained model. \n"
|
help="The config json file corresponding to the pre-trained model. \n"
|
||||||
"This specifies the model architecture. If not given and "
|
"This specifies the model architecture. If not given and "
|
||||||
"--pytorch_checkpoint_path is not given or is a shortcut name"
|
"--pytorch_checkpoint_path is not given or is a shortcut name"
|
||||||
"use the configuration associated to the shortcut name on the AWS")
|
"use the configuration associated to the shortcut name on the AWS",
|
||||||
parser.add_argument("--compare_with_pt_model",
|
)
|
||||||
action='store_true',
|
parser.add_argument(
|
||||||
help = "Compare Tensorflow and PyTorch model predictions.")
|
"--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
|
||||||
parser.add_argument("--use_cached_models",
|
)
|
||||||
action='store_true',
|
parser.add_argument(
|
||||||
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
|
"--use_cached_models",
|
||||||
parser.add_argument("--remove_cached_files",
|
action="store_true",
|
||||||
action='store_true',
|
help="Use cached models if possible instead of updating to latest checkpoint versions.",
|
||||||
help = "Remove pytorch models after conversion (save memory when converting in batches).")
|
)
|
||||||
parser.add_argument("--only_convert_finetuned_models",
|
parser.add_argument(
|
||||||
action='store_true',
|
"--remove_cached_files",
|
||||||
help = "Only convert finetuned models.")
|
action="store_true",
|
||||||
|
help="Remove pytorch models after conversion (save memory when converting in batches).",
|
||||||
|
)
|
||||||
|
parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# if args.pytorch_checkpoint_path is not None:
|
# if args.pytorch_checkpoint_path is not None:
|
||||||
@@ -248,11 +463,15 @@ if __name__ == "__main__":
|
|||||||
# compare_with_pt_model=args.compare_with_pt_model,
|
# compare_with_pt_model=args.compare_with_pt_model,
|
||||||
# use_cached_models=args.use_cached_models)
|
# use_cached_models=args.use_cached_models)
|
||||||
# else:
|
# else:
|
||||||
convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
|
convert_all_pt_checkpoints_to_tf(
|
||||||
|
args.model_type.lower() if args.model_type is not None else None,
|
||||||
args.tf_dump_path,
|
args.tf_dump_path,
|
||||||
model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
|
model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
|
||||||
|
if args.pytorch_checkpoint_path is not None
|
||||||
|
else None,
|
||||||
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
|
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
|
||||||
compare_with_pt_model=args.compare_with_pt_model,
|
compare_with_pt_model=args.compare_with_pt_model,
|
||||||
use_cached_models=args.use_cached_models,
|
use_cached_models=args.use_cached_models,
|
||||||
remove_cached_files=args.remove_cached_files,
|
remove_cached_files=args.remove_cached_files,
|
||||||
only_convert_finetuned_models=args.only_convert_finetuned_models)
|
only_convert_finetuned_models=args.only_convert_finetuned_models,
|
||||||
|
)
|
||||||
|
|||||||
@@ -30,20 +30,27 @@ if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
|||||||
|
|
||||||
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
|
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
|
||||||
from fairseq.modules import TransformerSentenceEncoderLayer
|
from fairseq.modules import TransformerSentenceEncoderLayer
|
||||||
from transformers.modeling_bert import (BertConfig, BertEncoder,
|
from transformers.modeling_bert import (
|
||||||
BertIntermediate, BertLayer,
|
BertConfig,
|
||||||
BertModel, BertOutput,
|
BertEncoder,
|
||||||
|
BertIntermediate,
|
||||||
|
BertLayer,
|
||||||
|
BertModel,
|
||||||
|
BertOutput,
|
||||||
BertSelfAttention,
|
BertSelfAttention,
|
||||||
BertSelfOutput)
|
BertSelfOutput,
|
||||||
from transformers.modeling_roberta import (RobertaEmbeddings,
|
)
|
||||||
|
from transformers.modeling_roberta import (
|
||||||
|
RobertaEmbeddings,
|
||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
RobertaForSequenceClassification,
|
RobertaForSequenceClassification,
|
||||||
RobertaModel)
|
RobertaModel,
|
||||||
|
)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SAMPLE_TEXT = 'Hello world! cécé herlolip'
|
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
||||||
|
|
||||||
|
|
||||||
def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):
|
def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):
|
||||||
@@ -74,7 +81,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
|||||||
# Embeddings
|
# Embeddings
|
||||||
model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
|
model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
|
||||||
model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
|
model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
|
||||||
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them.
|
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
|
||||||
|
model.roberta.embeddings.token_type_embeddings.weight
|
||||||
|
) # just zero them out b/c RoBERTa doesn't use them.
|
||||||
model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
|
model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
|
||||||
model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias
|
model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias
|
||||||
|
|
||||||
@@ -86,10 +95,10 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
|||||||
### self attention
|
### self attention
|
||||||
self_attn: BertSelfAttention = layer.attention.self
|
self_attn: BertSelfAttention = layer.attention.self
|
||||||
assert (
|
assert (
|
||||||
roberta_layer.self_attn.k_proj.weight.data.shape == \
|
roberta_layer.self_attn.k_proj.weight.data.shape
|
||||||
roberta_layer.self_attn.q_proj.weight.data.shape == \
|
== roberta_layer.self_attn.q_proj.weight.data.shape
|
||||||
roberta_layer.self_attn.v_proj.weight.data.shape == \
|
== roberta_layer.self_attn.v_proj.weight.data.shape
|
||||||
torch.Size((config.hidden_size, config.hidden_size))
|
== torch.Size((config.hidden_size, config.hidden_size))
|
||||||
)
|
)
|
||||||
|
|
||||||
self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight
|
self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight
|
||||||
@@ -101,9 +110,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
|||||||
|
|
||||||
### self-attention output
|
### self-attention output
|
||||||
self_output: BertSelfOutput = layer.attention.output
|
self_output: BertSelfOutput = layer.attention.output
|
||||||
assert(
|
assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
|
||||||
self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
|
|
||||||
)
|
|
||||||
self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
|
self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
|
||||||
self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
|
self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
|
||||||
self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
|
self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
|
||||||
@@ -111,17 +118,13 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
|||||||
|
|
||||||
### intermediate
|
### intermediate
|
||||||
intermediate: BertIntermediate = layer.intermediate
|
intermediate: BertIntermediate = layer.intermediate
|
||||||
assert(
|
assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
|
||||||
intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
|
|
||||||
)
|
|
||||||
intermediate.dense.weight = roberta_layer.fc1.weight
|
intermediate.dense.weight = roberta_layer.fc1.weight
|
||||||
intermediate.dense.bias = roberta_layer.fc1.bias
|
intermediate.dense.bias = roberta_layer.fc1.bias
|
||||||
|
|
||||||
### output
|
### output
|
||||||
bert_output: BertOutput = layer.output
|
bert_output: BertOutput = layer.output
|
||||||
assert(
|
assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
|
||||||
bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
|
|
||||||
)
|
|
||||||
bert_output.dense.weight = roberta_layer.fc2.weight
|
bert_output.dense.weight = roberta_layer.fc2.weight
|
||||||
bert_output.dense.bias = roberta_layer.fc2.bias
|
bert_output.dense.bias = roberta_layer.fc2.bias
|
||||||
bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
|
bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
|
||||||
@@ -129,10 +132,10 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
|||||||
#### end of layer
|
#### end of layer
|
||||||
|
|
||||||
if classification_head:
|
if classification_head:
|
||||||
model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight
|
model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight
|
||||||
model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias
|
model.classifier.dense.bias = roberta.model.classification_heads["mnli"].dense.bias
|
||||||
model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight
|
model.classifier.out_proj.weight = roberta.model.classification_heads["mnli"].out_proj.weight
|
||||||
model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias
|
model.classifier.out_proj.bias = roberta.model.classification_heads["mnli"].out_proj.bias
|
||||||
else:
|
else:
|
||||||
# LM Head
|
# LM Head
|
||||||
model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
|
model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
|
||||||
@@ -147,17 +150,14 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
|||||||
|
|
||||||
our_output = model(input_ids)[0]
|
our_output = model(input_ids)[0]
|
||||||
if classification_head:
|
if classification_head:
|
||||||
their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
|
their_output = roberta.model.classification_heads["mnli"](roberta.extract_features(input_ids))
|
||||||
else:
|
else:
|
||||||
their_output = roberta.model(input_ids)[0]
|
their_output = roberta.model(input_ids)[0]
|
||||||
print(our_output.shape, their_output.shape)
|
print(our_output.shape, their_output.shape)
|
||||||
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
|
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
|
||||||
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
|
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
|
||||||
success = torch.allclose(our_output, their_output, atol=1e-3)
|
success = torch.allclose(our_output, their_output, atol=1e-3)
|
||||||
print(
|
print("Do both models output the same tensors?", "🔥" if success else "💩")
|
||||||
"Do both models output the same tensors?",
|
|
||||||
"🔥" if success else "💩"
|
|
||||||
)
|
|
||||||
if not success:
|
if not success:
|
||||||
raise Exception("Something went wRoNg")
|
raise Exception("Something went wRoNg")
|
||||||
|
|
||||||
@@ -169,23 +169,16 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--roberta_checkpoint_path",
|
parser.add_argument(
|
||||||
default = None,
|
"--roberta_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
|
||||||
type = str,
|
)
|
||||||
required = True,
|
parser.add_argument(
|
||||||
help = "Path the official PyTorch dump.")
|
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||||
parser.add_argument("--pytorch_dump_folder_path",
|
)
|
||||||
default = None,
|
parser.add_argument(
|
||||||
type = str,
|
"--classification_head", action="store_true", help="Whether to convert a final classification head."
|
||||||
required = True,
|
)
|
||||||
help = "Path to the output PyTorch model.")
|
|
||||||
parser.add_argument("--classification_head",
|
|
||||||
action = "store_true",
|
|
||||||
help = "Whether to convert a final classification head.")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_roberta_checkpoint_to_pytorch(
|
convert_roberta_checkpoint_to_pytorch(
|
||||||
args.roberta_checkpoint_path,
|
args.roberta_checkpoint_path, args.pytorch_dump_folder_path, args.classification_head
|
||||||
args.pytorch_dump_folder_path,
|
|
||||||
args.classification_head
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -24,8 +24,10 @@ import torch
|
|||||||
from transformers import T5Config, T5Model, load_tf_weights_in_t5
|
from transformers import T5Config, T5Model, load_tf_weights_in_t5
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||||
# Initialise PyTorch model
|
# Initialise PyTorch model
|
||||||
config = T5Config.from_json_file(config_file)
|
config = T5Config.from_json_file(config_file)
|
||||||
@@ -43,23 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--tf_checkpoint_path",
|
parser.add_argument(
|
||||||
default = None,
|
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||||
type = str,
|
)
|
||||||
required = True,
|
parser.add_argument(
|
||||||
help = "Path to the TensorFlow checkpoint path.")
|
"--config_file",
|
||||||
parser.add_argument("--config_file",
|
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The config json file corresponding to the pre-trained T5 model. \n"
|
help="The config json file corresponding to the pre-trained T5 model. \n"
|
||||||
"This specifies the model architecture.")
|
"This specifies the model architecture.",
|
||||||
parser.add_argument("--pytorch_dump_path",
|
)
|
||||||
default = None,
|
parser.add_argument(
|
||||||
type = str,
|
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||||
required = True,
|
)
|
||||||
help = "Path to the output PyTorch model.")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
||||||
args.config_file,
|
|
||||||
args.pytorch_dump_path)
|
|
||||||
|
|||||||
@@ -26,9 +26,8 @@ import torch
|
|||||||
import transformers.tokenization_transfo_xl as data_utils
|
import transformers.tokenization_transfo_xl as data_utils
|
||||||
|
|
||||||
from transformers import CONFIG_NAME, WEIGHTS_NAME
|
from transformers import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from transformers import (TransfoXLConfig, TransfoXLLMHeadModel,
|
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
|
||||||
load_tf_weights_in_transfo_xl)
|
from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
||||||
from transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)
|
|
||||||
|
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
import cPickle as pickle
|
import cPickle as pickle
|
||||||
@@ -36,32 +35,33 @@ else:
|
|||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
# We do this to be able to load python 2 datasets pickles
|
# We do this to be able to load python 2 datasets pickles
|
||||||
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
|
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
|
||||||
data_utils.Vocab = data_utils.TransfoXLTokenizer
|
data_utils.Vocab = data_utils.TransfoXLTokenizer
|
||||||
data_utils.Corpus = data_utils.TransfoXLCorpus
|
data_utils.Corpus = data_utils.TransfoXLCorpus
|
||||||
sys.modules['data_utils'] = data_utils
|
sys.modules["data_utils"] = data_utils
|
||||||
sys.modules['vocabulary'] = data_utils
|
sys.modules["vocabulary"] = data_utils
|
||||||
|
|
||||||
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
|
||||||
transfo_xl_config_file,
|
def convert_transfo_xl_checkpoint_to_pytorch(
|
||||||
pytorch_dump_folder_path,
|
tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file
|
||||||
transfo_xl_dataset_file):
|
):
|
||||||
if transfo_xl_dataset_file:
|
if transfo_xl_dataset_file:
|
||||||
# Convert a pre-processed corpus (see original TensorFlow repo)
|
# Convert a pre-processed corpus (see original TensorFlow repo)
|
||||||
with open(transfo_xl_dataset_file, "rb") as fp:
|
with open(transfo_xl_dataset_file, "rb") as fp:
|
||||||
corpus = pickle.load(fp, encoding="latin1")
|
corpus = pickle.load(fp, encoding="latin1")
|
||||||
# Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
|
# Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
|
||||||
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
|
pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["pretrained_vocab_file"]
|
||||||
print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
|
print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
|
||||||
corpus_vocab_dict = corpus.vocab.__dict__
|
corpus_vocab_dict = corpus.vocab.__dict__
|
||||||
torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
|
torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
|
||||||
|
|
||||||
corpus_dict_no_vocab = corpus.__dict__
|
corpus_dict_no_vocab = corpus.__dict__
|
||||||
corpus_dict_no_vocab.pop('vocab', None)
|
corpus_dict_no_vocab.pop("vocab", None)
|
||||||
pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
|
pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME
|
||||||
print("Save dataset to {}".format(pytorch_dataset_dump_path))
|
print("Save dataset to {}".format(pytorch_dataset_dump_path))
|
||||||
torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
|
torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
|
||||||
|
|
||||||
@@ -92,26 +92,36 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--pytorch_dump_folder_path",
|
parser.add_argument(
|
||||||
|
"--pytorch_dump_folder_path",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help = "Path to the folder to store the PyTorch model or dataset/vocab.")
|
help="Path to the folder to store the PyTorch model or dataset/vocab.",
|
||||||
parser.add_argument("--tf_checkpoint_path",
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tf_checkpoint_path",
|
||||||
default="",
|
default="",
|
||||||
type=str,
|
type=str,
|
||||||
help = "An optional path to a TensorFlow checkpoint path to be converted.")
|
help="An optional path to a TensorFlow checkpoint path to be converted.",
|
||||||
parser.add_argument("--transfo_xl_config_file",
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--transfo_xl_config_file",
|
||||||
default="",
|
default="",
|
||||||
type=str,
|
type=str,
|
||||||
help="An optional config json file corresponding to the pre-trained BERT model. \n"
|
help="An optional config json file corresponding to the pre-trained BERT model. \n"
|
||||||
"This specifies the model architecture.")
|
"This specifies the model architecture.",
|
||||||
parser.add_argument("--transfo_xl_dataset_file",
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--transfo_xl_dataset_file",
|
||||||
default="",
|
default="",
|
||||||
type=str,
|
type=str,
|
||||||
help = "An optional dataset file to be converted in a vocabulary.")
|
help="An optional dataset file to be converted in a vocabulary.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
convert_transfo_xl_checkpoint_to_pytorch(
|
||||||
|
args.tf_checkpoint_path,
|
||||||
args.transfo_xl_config_file,
|
args.transfo_xl_config_file,
|
||||||
args.pytorch_dump_folder_path,
|
args.pytorch_dump_folder_path,
|
||||||
args.transfo_xl_dataset_file)
|
args.transfo_xl_dataset_file,
|
||||||
|
)
|
||||||
|
|||||||
@@ -27,32 +27,34 @@ from transformers import CONFIG_NAME, WEIGHTS_NAME
|
|||||||
from transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
from transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
|
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
|
||||||
# Load checkpoint
|
# Load checkpoint
|
||||||
chkpt = torch.load(xlm_checkpoint_path, map_location='cpu')
|
chkpt = torch.load(xlm_checkpoint_path, map_location="cpu")
|
||||||
|
|
||||||
state_dict = chkpt['model']
|
state_dict = chkpt["model"]
|
||||||
|
|
||||||
# We have the base model one level deeper than the original XLM repository
|
# We have the base model one level deeper than the original XLM repository
|
||||||
two_levels_state_dict = {}
|
two_levels_state_dict = {}
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
if 'pred_layer' in k:
|
if "pred_layer" in k:
|
||||||
two_levels_state_dict[k] = v
|
two_levels_state_dict[k] = v
|
||||||
else:
|
else:
|
||||||
two_levels_state_dict['transformer.' + k] = v
|
two_levels_state_dict["transformer." + k] = v
|
||||||
|
|
||||||
config = chkpt['params']
|
config = chkpt["params"]
|
||||||
config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))
|
config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))
|
||||||
|
|
||||||
vocab = chkpt['dico_word2id']
|
vocab = chkpt["dico_word2id"]
|
||||||
vocab = dict((s + '</w>' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items())
|
vocab = dict((s + "</w>" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""), i) for s, i in vocab.items())
|
||||||
|
|
||||||
# Save pytorch-model
|
# Save pytorch-model
|
||||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
|
||||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||||
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
|
pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["vocab_file"]
|
||||||
|
|
||||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||||
torch.save(two_levels_state_dict, pytorch_weights_dump_path)
|
torch.save(two_levels_state_dict, pytorch_weights_dump_path)
|
||||||
@@ -69,15 +71,11 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--xlm_checkpoint_path",
|
parser.add_argument(
|
||||||
default = None,
|
"--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
|
||||||
type = str,
|
)
|
||||||
required = True,
|
parser.add_argument(
|
||||||
help = "Path the official PyTorch dump.")
|
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||||
parser.add_argument("--pytorch_dump_folder_path",
|
)
|
||||||
default = None,
|
|
||||||
type = str,
|
|
||||||
required = True,
|
|
||||||
help = "Path to the output PyTorch model.")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)
|
convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)
|
||||||
|
|||||||
@@ -22,11 +22,15 @@ import os
|
|||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import (CONFIG_NAME, WEIGHTS_NAME,
|
from transformers import (
|
||||||
|
CONFIG_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
XLNetLMHeadModel, XLNetForQuestionAnswering,
|
XLNetLMHeadModel,
|
||||||
|
XLNetForQuestionAnswering,
|
||||||
XLNetForSequenceClassification,
|
XLNetForSequenceClassification,
|
||||||
load_tf_weights_in_xlnet)
|
load_tf_weights_in_xlnet,
|
||||||
|
)
|
||||||
|
|
||||||
GLUE_TASKS_NUM_LABELS = {
|
GLUE_TASKS_NUM_LABELS = {
|
||||||
"cola": 2,
|
"cola": 2,
|
||||||
@@ -41,9 +45,13 @@ GLUE_TASKS_NUM_LABELS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
|
|
||||||
|
def convert_xlnet_checkpoint_to_pytorch(
|
||||||
|
tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None
|
||||||
|
):
|
||||||
# Initialise PyTorch model
|
# Initialise PyTorch model
|
||||||
config = XLNetConfig.from_json_file(bert_config_file)
|
config = XLNetConfig.from_json_file(bert_config_file)
|
||||||
|
|
||||||
@@ -53,7 +61,7 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
|
|||||||
config.finetuning_task = finetuning_task
|
config.finetuning_task = finetuning_task
|
||||||
config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
|
config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
|
||||||
model = XLNetForSequenceClassification(config)
|
model = XLNetForSequenceClassification(config)
|
||||||
elif 'squad' in finetuning_task:
|
elif "squad" in finetuning_task:
|
||||||
config.finetuning_task = finetuning_task
|
config.finetuning_task = finetuning_task
|
||||||
model = XLNetForQuestionAnswering(config)
|
model = XLNetForQuestionAnswering(config)
|
||||||
else:
|
else:
|
||||||
@@ -75,30 +83,33 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--tf_checkpoint_path",
|
parser.add_argument(
|
||||||
default = None,
|
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||||
type = str,
|
)
|
||||||
required = True,
|
parser.add_argument(
|
||||||
help = "Path to the TensorFlow checkpoint path.")
|
"--xlnet_config_file",
|
||||||
parser.add_argument("--xlnet_config_file",
|
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The config json file corresponding to the pre-trained XLNet model. \n"
|
help="The config json file corresponding to the pre-trained XLNet model. \n"
|
||||||
"This specifies the model architecture.")
|
"This specifies the model architecture.",
|
||||||
parser.add_argument("--pytorch_dump_folder_path",
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pytorch_dump_folder_path",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help = "Path to the folder to store the PyTorch model or dataset/vocab.")
|
help="Path to the folder to store the PyTorch model or dataset/vocab.",
|
||||||
parser.add_argument("--finetuning_task",
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--finetuning_task",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned")
|
help="Name of a task on which the XLNet TensorFloaw model was fine-tuned",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
convert_xlnet_checkpoint_to_pytorch(
|
||||||
args.xlnet_config_file,
|
args.tf_checkpoint_path, args.xlnet_config_file, args.pytorch_dump_folder_path, args.finetuning_task
|
||||||
args.pytorch_dump_folder_path,
|
)
|
||||||
args.finetuning_task)
|
|
||||||
|
|||||||
@@ -1,8 +1,15 @@
|
|||||||
from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures, SingleSentenceClassificationProcessor
|
from .processors import (
|
||||||
|
InputExample,
|
||||||
|
InputFeatures,
|
||||||
|
DataProcessor,
|
||||||
|
SquadFeatures,
|
||||||
|
SingleSentenceClassificationProcessor,
|
||||||
|
)
|
||||||
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
||||||
from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor
|
from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor
|
||||||
from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
|
from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
|
||||||
|
|
||||||
from .metrics import is_sklearn_available
|
from .metrics import is_sklearn_available
|
||||||
|
|
||||||
if is_sklearn_available():
|
if is_sklearn_available():
|
||||||
from .metrics import glue_compute_metrics, xnli_compute_metrics
|
from .metrics import glue_compute_metrics, xnli_compute_metrics
|
||||||
|
|||||||
@@ -23,20 +23,22 @@ logger = logging.getLogger(__name__)
|
|||||||
try:
|
try:
|
||||||
from scipy.stats import pearsonr, spearmanr
|
from scipy.stats import pearsonr, spearmanr
|
||||||
from sklearn.metrics import matthews_corrcoef, f1_score
|
from sklearn.metrics import matthews_corrcoef, f1_score
|
||||||
|
|
||||||
_has_sklearn = True
|
_has_sklearn = True
|
||||||
except (AttributeError, ImportError) as e:
|
except (AttributeError, ImportError) as e:
|
||||||
logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
|
logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
|
||||||
_has_sklearn = False
|
_has_sklearn = False
|
||||||
|
|
||||||
|
|
||||||
def is_sklearn_available():
|
def is_sklearn_available():
|
||||||
return _has_sklearn
|
return _has_sklearn
|
||||||
|
|
||||||
|
|
||||||
if _has_sklearn:
|
if _has_sklearn:
|
||||||
|
|
||||||
def simple_accuracy(preds, labels):
|
def simple_accuracy(preds, labels):
|
||||||
return (preds == labels).mean()
|
return (preds == labels).mean()
|
||||||
|
|
||||||
|
|
||||||
def acc_and_f1(preds, labels):
|
def acc_and_f1(preds, labels):
|
||||||
acc = simple_accuracy(preds, labels)
|
acc = simple_accuracy(preds, labels)
|
||||||
f1 = f1_score(y_true=labels, y_pred=preds)
|
f1 = f1_score(y_true=labels, y_pred=preds)
|
||||||
@@ -46,7 +48,6 @@ if _has_sklearn:
|
|||||||
"acc_and_f1": (acc + f1) / 2,
|
"acc_and_f1": (acc + f1) / 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def pearson_and_spearman(preds, labels):
|
def pearson_and_spearman(preds, labels):
|
||||||
pearson_corr = pearsonr(preds, labels)[0]
|
pearson_corr = pearsonr(preds, labels)[0]
|
||||||
spearman_corr = spearmanr(preds, labels)[0]
|
spearman_corr = spearmanr(preds, labels)[0]
|
||||||
@@ -56,7 +57,6 @@ if _has_sklearn:
|
|||||||
"corr": (pearson_corr + spearman_corr) / 2,
|
"corr": (pearson_corr + spearman_corr) / 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def glue_compute_metrics(task_name, preds, labels):
|
def glue_compute_metrics(task_name, preds, labels):
|
||||||
assert len(preds) == len(labels)
|
assert len(preds) == len(labels)
|
||||||
if task_name == "cola":
|
if task_name == "cola":
|
||||||
@@ -82,7 +82,6 @@ if _has_sklearn:
|
|||||||
else:
|
else:
|
||||||
raise KeyError(task_name)
|
raise KeyError(task_name)
|
||||||
|
|
||||||
|
|
||||||
def xnli_compute_metrics(task_name, preds, labels):
|
def xnli_compute_metrics(task_name, preds, labels):
|
||||||
assert len(preds) == len(labels)
|
assert len(preds) == len(labels)
|
||||||
if task_name == "xnli":
|
if task_name == "xnli":
|
||||||
|
|||||||
@@ -24,19 +24,21 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def normalize_answer(s):
|
def normalize_answer(s):
|
||||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||||
|
|
||||||
def remove_articles(text):
|
def remove_articles(text):
|
||||||
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
|
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
|
||||||
return re.sub(regex, ' ', text)
|
return re.sub(regex, " ", text)
|
||||||
|
|
||||||
def white_space_fix(text):
|
def white_space_fix(text):
|
||||||
return ' '.join(text.split())
|
return " ".join(text.split())
|
||||||
|
|
||||||
def remove_punc(text):
|
def remove_punc(text):
|
||||||
exclude = set(string.punctuation)
|
exclude = set(string.punctuation)
|
||||||
return ''.join(ch for ch in text if ch not in exclude)
|
return "".join(ch for ch in text if ch not in exclude)
|
||||||
|
|
||||||
def lower(text):
|
def lower(text):
|
||||||
return text.lower()
|
return text.lower()
|
||||||
|
|
||||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||||
|
|
||||||
|
|
||||||
@@ -75,14 +77,14 @@ def get_raw_scores(examples, preds):
|
|||||||
|
|
||||||
for example in examples:
|
for example in examples:
|
||||||
qas_id = example.qas_id
|
qas_id = example.qas_id
|
||||||
gold_answers = [answer['text'] for answer in example.answers if normalize_answer(answer['text'])]
|
gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]
|
||||||
|
|
||||||
if not gold_answers:
|
if not gold_answers:
|
||||||
# For unanswerable questions, only correct answer is empty string
|
# For unanswerable questions, only correct answer is empty string
|
||||||
gold_answers = ['']
|
gold_answers = [""]
|
||||||
|
|
||||||
if qas_id not in preds:
|
if qas_id not in preds:
|
||||||
print('Missing prediction for %s' % qas_id)
|
print("Missing prediction for %s" % qas_id)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
prediction = preds[qas_id]
|
prediction = preds[qas_id]
|
||||||
@@ -106,23 +108,27 @@ def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
|
|||||||
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
|
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
|
||||||
if not qid_list:
|
if not qid_list:
|
||||||
total = len(exact_scores)
|
total = len(exact_scores)
|
||||||
return collections.OrderedDict([
|
return collections.OrderedDict(
|
||||||
('exact', 100.0 * sum(exact_scores.values()) / total),
|
[
|
||||||
('f1', 100.0 * sum(f1_scores.values()) / total),
|
("exact", 100.0 * sum(exact_scores.values()) / total),
|
||||||
('total', total),
|
("f1", 100.0 * sum(f1_scores.values()) / total),
|
||||||
])
|
("total", total),
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
total = len(qid_list)
|
total = len(qid_list)
|
||||||
return collections.OrderedDict([
|
return collections.OrderedDict(
|
||||||
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
|
[
|
||||||
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
|
("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
|
||||||
('total', total),
|
("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
|
||||||
])
|
("total", total),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def merge_eval(main_eval, new_eval, prefix):
|
def merge_eval(main_eval, new_eval, prefix):
|
||||||
for k in new_eval:
|
for k in new_eval:
|
||||||
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
|
main_eval["%s_%s" % (prefix, k)] = new_eval[k]
|
||||||
|
|
||||||
|
|
||||||
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
|
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
|
||||||
@@ -160,16 +166,14 @@ def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
|
|||||||
|
|
||||||
|
|
||||||
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
||||||
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(
|
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
|
||||||
preds, exact_raw, na_probs, qid_to_has_ans)
|
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
|
||||||
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(
|
main_eval["best_exact"] = best_exact
|
||||||
preds, f1_raw, na_probs, qid_to_has_ans)
|
main_eval["best_exact_thresh"] = exact_thresh
|
||||||
main_eval['best_exact'] = best_exact
|
main_eval["best_f1"] = best_f1
|
||||||
main_eval['best_exact_thresh'] = exact_thresh
|
main_eval["best_f1_thresh"] = f1_thresh
|
||||||
main_eval['best_f1'] = best_f1
|
main_eval["has_ans_exact"] = has_ans_exact
|
||||||
main_eval['best_f1_thresh'] = f1_thresh
|
main_eval["has_ans_f1"] = has_ans_f1
|
||||||
main_eval['has_ans_exact'] = has_ans_exact
|
|
||||||
main_eval['has_ans_f1'] = has_ans_f1
|
|
||||||
|
|
||||||
|
|
||||||
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
||||||
@@ -199,10 +203,10 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_h
|
|||||||
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
|
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
|
||||||
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
|
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
|
||||||
|
|
||||||
main_eval['best_exact'] = best_exact
|
main_eval["best_exact"] = best_exact
|
||||||
main_eval['best_exact_thresh'] = exact_thresh
|
main_eval["best_exact_thresh"] = exact_thresh
|
||||||
main_eval['best_f1'] = best_f1
|
main_eval["best_f1"] = best_f1
|
||||||
main_eval['best_f1_thresh'] = f1_thresh
|
main_eval["best_f1_thresh"] = f1_thresh
|
||||||
|
|
||||||
|
|
||||||
def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
|
def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
|
||||||
@@ -215,18 +219,20 @@ def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_
|
|||||||
|
|
||||||
exact, f1 = get_raw_scores(examples, preds)
|
exact, f1 = get_raw_scores(examples, preds)
|
||||||
|
|
||||||
exact_threshold = apply_no_ans_threshold(exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
|
exact_threshold = apply_no_ans_threshold(
|
||||||
|
exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
|
||||||
|
)
|
||||||
f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
|
f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
|
||||||
|
|
||||||
evaluation = make_eval_dict(exact_threshold, f1_threshold)
|
evaluation = make_eval_dict(exact_threshold, f1_threshold)
|
||||||
|
|
||||||
if has_answer_qids:
|
if has_answer_qids:
|
||||||
has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
|
has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
|
||||||
merge_eval(evaluation, has_ans_eval, 'HasAns')
|
merge_eval(evaluation, has_ans_eval, "HasAns")
|
||||||
|
|
||||||
if no_answer_qids:
|
if no_answer_qids:
|
||||||
no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
|
no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
|
||||||
merge_eval(evaluation, no_ans_eval, 'NoAns')
|
merge_eval(evaluation, no_ans_eval, "NoAns")
|
||||||
|
|
||||||
if no_answer_probs:
|
if no_answer_probs:
|
||||||
find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
|
find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
|
||||||
@@ -284,8 +290,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
|||||||
start_position = tok_text.find(pred_text)
|
start_position = tok_text.find(pred_text)
|
||||||
if start_position == -1:
|
if start_position == -1:
|
||||||
if verbose_logging:
|
if verbose_logging:
|
||||||
logger.info(
|
logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
|
||||||
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
|
|
||||||
return orig_text
|
return orig_text
|
||||||
end_position = start_position + len(pred_text) - 1
|
end_position = start_position + len(pred_text) - 1
|
||||||
|
|
||||||
@@ -294,8 +299,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
|||||||
|
|
||||||
if len(orig_ns_text) != len(tok_ns_text):
|
if len(orig_ns_text) != len(tok_ns_text):
|
||||||
if verbose_logging:
|
if verbose_logging:
|
||||||
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
|
logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text)
|
||||||
orig_ns_text, tok_ns_text)
|
|
||||||
return orig_text
|
return orig_text
|
||||||
|
|
||||||
# We then project the characters in `pred_text` back to `orig_text` using
|
# We then project the characters in `pred_text` back to `orig_text` using
|
||||||
@@ -393,8 +397,8 @@ def compute_predictions_logits(
|
|||||||
unique_id_to_result[result.unique_id] = result
|
unique_id_to_result[result.unique_id] = result
|
||||||
|
|
||||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
"PrelimPrediction",
|
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
|
||||||
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
|
)
|
||||||
|
|
||||||
all_predictions = collections.OrderedDict()
|
all_predictions = collections.OrderedDict()
|
||||||
all_nbest_json = collections.OrderedDict()
|
all_nbest_json = collections.OrderedDict()
|
||||||
@@ -447,7 +451,9 @@ def compute_predictions_logits(
|
|||||||
start_index=start_index,
|
start_index=start_index,
|
||||||
end_index=end_index,
|
end_index=end_index,
|
||||||
start_logit=result.start_logits[start_index],
|
start_logit=result.start_logits[start_index],
|
||||||
end_logit=result.end_logits[end_index]))
|
end_logit=result.end_logits[end_index],
|
||||||
|
)
|
||||||
|
)
|
||||||
if version_2_with_negative:
|
if version_2_with_negative:
|
||||||
prelim_predictions.append(
|
prelim_predictions.append(
|
||||||
_PrelimPrediction(
|
_PrelimPrediction(
|
||||||
@@ -455,14 +461,14 @@ def compute_predictions_logits(
|
|||||||
start_index=0,
|
start_index=0,
|
||||||
end_index=0,
|
end_index=0,
|
||||||
start_logit=null_start_logit,
|
start_logit=null_start_logit,
|
||||||
end_logit=null_end_logit))
|
end_logit=null_end_logit,
|
||||||
prelim_predictions = sorted(
|
)
|
||||||
prelim_predictions,
|
)
|
||||||
key=lambda x: (x.start_logit + x.end_logit),
|
prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
|
||||||
reverse=True)
|
|
||||||
|
|
||||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
"NbestPrediction", ["text", "start_logit", "end_logit"])
|
"NbestPrediction", ["text", "start_logit", "end_logit"]
|
||||||
|
)
|
||||||
|
|
||||||
seen_predictions = {}
|
seen_predictions = {}
|
||||||
nbest = []
|
nbest = []
|
||||||
@@ -498,31 +504,21 @@ def compute_predictions_logits(
|
|||||||
final_text = ""
|
final_text = ""
|
||||||
seen_predictions[final_text] = True
|
seen_predictions[final_text] = True
|
||||||
|
|
||||||
nbest.append(
|
nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
|
||||||
_NbestPrediction(
|
|
||||||
text=final_text,
|
|
||||||
start_logit=pred.start_logit,
|
|
||||||
end_logit=pred.end_logit))
|
|
||||||
# if we didn't include the empty option in the n-best, include it
|
# if we didn't include the empty option in the n-best, include it
|
||||||
if version_2_with_negative:
|
if version_2_with_negative:
|
||||||
if "" not in seen_predictions:
|
if "" not in seen_predictions:
|
||||||
nbest.append(
|
nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
|
||||||
_NbestPrediction(
|
|
||||||
text="",
|
|
||||||
start_logit=null_start_logit,
|
|
||||||
end_logit=null_end_logit))
|
|
||||||
|
|
||||||
# In very rare edge cases we could only have single null prediction.
|
# In very rare edge cases we could only have single null prediction.
|
||||||
# So we just create a nonce prediction in this case to avoid failure.
|
# So we just create a nonce prediction in this case to avoid failure.
|
||||||
if len(nbest) == 1:
|
if len(nbest) == 1:
|
||||||
nbest.insert(0,
|
nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
|
||||||
|
|
||||||
# In very rare edge cases we could have no valid predictions. So we
|
# In very rare edge cases we could have no valid predictions. So we
|
||||||
# just create a nonce prediction in this case to avoid failure.
|
# just create a nonce prediction in this case to avoid failure.
|
||||||
if not nbest:
|
if not nbest:
|
||||||
nbest.append(
|
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
|
||||||
|
|
||||||
assert len(nbest) >= 1
|
assert len(nbest) >= 1
|
||||||
|
|
||||||
@@ -551,8 +547,7 @@ def compute_predictions_logits(
|
|||||||
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
||||||
else:
|
else:
|
||||||
# predict "" iff the null score - the score of best non-null > threshold
|
# predict "" iff the null score - the score of best non-null > threshold
|
||||||
score_diff = score_null - best_non_null_entry.start_logit - (
|
score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
|
||||||
best_non_null_entry.end_logit)
|
|
||||||
scores_diff_json[example.qas_id] = score_diff
|
scores_diff_json[example.qas_id] = score_diff
|
||||||
if score_diff > null_score_diff_threshold:
|
if score_diff > null_score_diff_threshold:
|
||||||
all_predictions[example.qas_id] = ""
|
all_predictions[example.qas_id] = ""
|
||||||
@@ -586,7 +581,7 @@ def compute_predictions_log_probs(
|
|||||||
end_n_top,
|
end_n_top,
|
||||||
version_2_with_negative,
|
version_2_with_negative,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
verbose_logging
|
verbose_logging,
|
||||||
):
|
):
|
||||||
""" XLNet write prediction logic (more complex than Bert's).
|
""" XLNet write prediction logic (more complex than Bert's).
|
||||||
Write final predictions to the json file and log-odds of null if needed.
|
Write final predictions to the json file and log-odds of null if needed.
|
||||||
@@ -594,12 +589,12 @@ def compute_predictions_log_probs(
|
|||||||
Requires utils_squad_evaluate.py
|
Requires utils_squad_evaluate.py
|
||||||
"""
|
"""
|
||||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
"PrelimPrediction",
|
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
|
||||||
["feature_index", "start_index", "end_index",
|
)
|
||||||
"start_log_prob", "end_log_prob"])
|
|
||||||
|
|
||||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
|
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Writing predictions to: %s", output_prediction_file)
|
logger.info("Writing predictions to: %s", output_prediction_file)
|
||||||
# logger.info("Writing nbest to: %s" % (output_nbest_file))
|
# logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||||
@@ -663,12 +658,13 @@ def compute_predictions_log_probs(
|
|||||||
start_index=start_index,
|
start_index=start_index,
|
||||||
end_index=end_index,
|
end_index=end_index,
|
||||||
start_log_prob=start_log_prob,
|
start_log_prob=start_log_prob,
|
||||||
end_log_prob=end_log_prob))
|
end_log_prob=end_log_prob,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
prelim_predictions = sorted(
|
prelim_predictions = sorted(
|
||||||
prelim_predictions,
|
prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
|
||||||
key=lambda x: (x.start_log_prob + x.end_log_prob),
|
)
|
||||||
reverse=True)
|
|
||||||
|
|
||||||
seen_predictions = {}
|
seen_predictions = {}
|
||||||
nbest = []
|
nbest = []
|
||||||
@@ -704,8 +700,7 @@ def compute_predictions_log_probs(
|
|||||||
else:
|
else:
|
||||||
do_lower_case = tokenizer.do_lowercase_and_remove_accent
|
do_lower_case = tokenizer.do_lowercase_and_remove_accent
|
||||||
|
|
||||||
final_text = get_final_text(tok_text, orig_text, do_lower_case,
|
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
|
||||||
verbose_logging)
|
|
||||||
|
|
||||||
if final_text in seen_predictions:
|
if final_text in seen_predictions:
|
||||||
continue
|
continue
|
||||||
@@ -713,17 +708,13 @@ def compute_predictions_log_probs(
|
|||||||
seen_predictions[final_text] = True
|
seen_predictions[final_text] = True
|
||||||
|
|
||||||
nbest.append(
|
nbest.append(
|
||||||
_NbestPrediction(
|
_NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
|
||||||
text=final_text,
|
)
|
||||||
start_log_prob=pred.start_log_prob,
|
|
||||||
end_log_prob=pred.end_log_prob))
|
|
||||||
|
|
||||||
# In very rare edge cases we could have no valid predictions. So we
|
# In very rare edge cases we could have no valid predictions. So we
|
||||||
# just create a nonce prediction in this case to avoid failure.
|
# just create a nonce prediction in this case to avoid failure.
|
||||||
if not nbest:
|
if not nbest:
|
||||||
nbest.append(
|
nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
|
||||||
_NbestPrediction(text="", start_log_prob=-1e6,
|
|
||||||
end_log_prob=-1e6))
|
|
||||||
|
|
||||||
total_scores = []
|
total_scores = []
|
||||||
best_non_null_entry = None
|
best_non_null_entry = None
|
||||||
|
|||||||
@@ -27,7 +27,9 @@ if is_tf_available():
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def glue_convert_examples_to_features(examples, tokenizer,
|
def glue_convert_examples_to_features(
|
||||||
|
examples,
|
||||||
|
tokenizer,
|
||||||
max_length=512,
|
max_length=512,
|
||||||
task=None,
|
task=None,
|
||||||
label_list=None,
|
label_list=None,
|
||||||
@@ -35,7 +37,8 @@ def glue_convert_examples_to_features(examples, tokenizer,
|
|||||||
pad_on_left=False,
|
pad_on_left=False,
|
||||||
pad_token=0,
|
pad_token=0,
|
||||||
pad_token_segment_id=0,
|
pad_token_segment_id=0,
|
||||||
mask_padding_with_zero=True):
|
mask_padding_with_zero=True,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Loads a data file into a list of ``InputFeatures``
|
Loads a data file into a list of ``InputFeatures``
|
||||||
|
|
||||||
@@ -82,12 +85,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
|
|||||||
example = processor.get_example_from_tensor_dict(example)
|
example = processor.get_example_from_tensor_dict(example)
|
||||||
example = processor.tfds_map(example)
|
example = processor.tfds_map(example)
|
||||||
|
|
||||||
inputs = tokenizer.encode_plus(
|
inputs = tokenizer.encode_plus(example.text_a, example.text_b, add_special_tokens=True, max_length=max_length,)
|
||||||
example.text_a,
|
|
||||||
example.text_b,
|
|
||||||
add_special_tokens=True,
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
|
||||||
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
|
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
|
||||||
|
|
||||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||||
@@ -106,8 +104,12 @@ def glue_convert_examples_to_features(examples, tokenizer,
|
|||||||
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
|
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
|
||||||
|
|
||||||
assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
|
assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
|
||||||
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
|
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
|
||||||
assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length)
|
len(attention_mask), max_length
|
||||||
|
)
|
||||||
|
assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
|
||||||
|
len(token_type_ids), max_length
|
||||||
|
)
|
||||||
|
|
||||||
if output_mode == "classification":
|
if output_mode == "classification":
|
||||||
label = label_map[example.label]
|
label = label_map[example.label]
|
||||||
@@ -125,28 +127,36 @@ def glue_convert_examples_to_features(examples, tokenizer,
|
|||||||
logger.info("label: %s (id = %d)" % (example.label, label))
|
logger.info("label: %s (id = %d)" % (example.label, label))
|
||||||
|
|
||||||
features.append(
|
features.append(
|
||||||
InputFeatures(input_ids=input_ids,
|
InputFeatures(
|
||||||
attention_mask=attention_mask,
|
input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, label=label
|
||||||
token_type_ids=token_type_ids,
|
)
|
||||||
label=label))
|
)
|
||||||
|
|
||||||
if is_tf_available() and is_tf_dataset:
|
if is_tf_available() and is_tf_dataset:
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
for ex in features:
|
for ex in features:
|
||||||
yield ({'input_ids': ex.input_ids,
|
yield (
|
||||||
'attention_mask': ex.attention_mask,
|
{
|
||||||
'token_type_ids': ex.token_type_ids},
|
"input_ids": ex.input_ids,
|
||||||
ex.label)
|
"attention_mask": ex.attention_mask,
|
||||||
|
"token_type_ids": ex.token_type_ids,
|
||||||
|
},
|
||||||
|
ex.label,
|
||||||
|
)
|
||||||
|
|
||||||
return tf.data.Dataset.from_generator(gen,
|
return tf.data.Dataset.from_generator(
|
||||||
({'input_ids': tf.int32,
|
gen,
|
||||||
'attention_mask': tf.int32,
|
({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
|
||||||
'token_type_ids': tf.int32},
|
(
|
||||||
tf.int64),
|
{
|
||||||
({'input_ids': tf.TensorShape([None]),
|
"input_ids": tf.TensorShape([None]),
|
||||||
'attention_mask': tf.TensorShape([None]),
|
"attention_mask": tf.TensorShape([None]),
|
||||||
'token_type_ids': tf.TensorShape([None])},
|
"token_type_ids": tf.TensorShape([None]),
|
||||||
tf.TensorShape([])))
|
},
|
||||||
|
tf.TensorShape([]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
@@ -156,21 +166,21 @@ class MrpcProcessor(DataProcessor):
|
|||||||
|
|
||||||
def get_example_from_tensor_dict(self, tensor_dict):
|
def get_example_from_tensor_dict(self, tensor_dict):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return InputExample(tensor_dict['idx'].numpy(),
|
return InputExample(
|
||||||
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
tensor_dict["idx"].numpy(),
|
||||||
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
||||||
str(tensor_dict['label'].numpy()))
|
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
||||||
|
str(tensor_dict["label"].numpy()),
|
||||||
|
)
|
||||||
|
|
||||||
def get_train_examples(self, data_dir):
|
def get_train_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
|
logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
|
||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
@@ -186,8 +196,7 @@ class MrpcProcessor(DataProcessor):
|
|||||||
text_a = line[3]
|
text_a = line[3]
|
||||||
text_b = line[4]
|
text_b = line[4]
|
||||||
label = line[0]
|
label = line[0]
|
||||||
examples.append(
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
@@ -196,21 +205,20 @@ class MnliProcessor(DataProcessor):
|
|||||||
|
|
||||||
def get_example_from_tensor_dict(self, tensor_dict):
|
def get_example_from_tensor_dict(self, tensor_dict):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return InputExample(tensor_dict['idx'].numpy(),
|
return InputExample(
|
||||||
tensor_dict['premise'].numpy().decode('utf-8'),
|
tensor_dict["idx"].numpy(),
|
||||||
tensor_dict['hypothesis'].numpy().decode('utf-8'),
|
tensor_dict["premise"].numpy().decode("utf-8"),
|
||||||
str(tensor_dict['label'].numpy()))
|
tensor_dict["hypothesis"].numpy().decode("utf-8"),
|
||||||
|
str(tensor_dict["label"].numpy()),
|
||||||
|
)
|
||||||
|
|
||||||
def get_train_examples(self, data_dir):
|
def get_train_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
|
||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
|
||||||
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
|
|
||||||
"dev_matched")
|
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
@@ -226,8 +234,7 @@ class MnliProcessor(DataProcessor):
|
|||||||
text_a = line[8]
|
text_a = line[8]
|
||||||
text_b = line[9]
|
text_b = line[9]
|
||||||
label = line[-1]
|
label = line[-1]
|
||||||
examples.append(
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
@@ -236,9 +243,7 @@ class MnliMismatchedProcessor(MnliProcessor):
|
|||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")
|
||||||
self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
|
|
||||||
"dev_matched")
|
|
||||||
|
|
||||||
|
|
||||||
class ColaProcessor(DataProcessor):
|
class ColaProcessor(DataProcessor):
|
||||||
@@ -246,20 +251,20 @@ class ColaProcessor(DataProcessor):
|
|||||||
|
|
||||||
def get_example_from_tensor_dict(self, tensor_dict):
|
def get_example_from_tensor_dict(self, tensor_dict):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return InputExample(tensor_dict['idx'].numpy(),
|
return InputExample(
|
||||||
tensor_dict['sentence'].numpy().decode('utf-8'),
|
tensor_dict["idx"].numpy(),
|
||||||
|
tensor_dict["sentence"].numpy().decode("utf-8"),
|
||||||
None,
|
None,
|
||||||
str(tensor_dict['label'].numpy()))
|
str(tensor_dict["label"].numpy()),
|
||||||
|
)
|
||||||
|
|
||||||
def get_train_examples(self, data_dir):
|
def get_train_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
|
||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
@@ -272,8 +277,7 @@ class ColaProcessor(DataProcessor):
|
|||||||
guid = "%s-%s" % (set_type, i)
|
guid = "%s-%s" % (set_type, i)
|
||||||
text_a = line[3]
|
text_a = line[3]
|
||||||
label = line[1]
|
label = line[1]
|
||||||
examples.append(
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
@@ -282,20 +286,20 @@ class Sst2Processor(DataProcessor):
|
|||||||
|
|
||||||
def get_example_from_tensor_dict(self, tensor_dict):
|
def get_example_from_tensor_dict(self, tensor_dict):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return InputExample(tensor_dict['idx'].numpy(),
|
return InputExample(
|
||||||
tensor_dict['sentence'].numpy().decode('utf-8'),
|
tensor_dict["idx"].numpy(),
|
||||||
|
tensor_dict["sentence"].numpy().decode("utf-8"),
|
||||||
None,
|
None,
|
||||||
str(tensor_dict['label'].numpy()))
|
str(tensor_dict["label"].numpy()),
|
||||||
|
)
|
||||||
|
|
||||||
def get_train_examples(self, data_dir):
|
def get_train_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
|
||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
@@ -310,8 +314,7 @@ class Sst2Processor(DataProcessor):
|
|||||||
guid = "%s-%s" % (set_type, i)
|
guid = "%s-%s" % (set_type, i)
|
||||||
text_a = line[0]
|
text_a = line[0]
|
||||||
label = line[1]
|
label = line[1]
|
||||||
examples.append(
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
@@ -320,20 +323,20 @@ class StsbProcessor(DataProcessor):
|
|||||||
|
|
||||||
def get_example_from_tensor_dict(self, tensor_dict):
|
def get_example_from_tensor_dict(self, tensor_dict):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return InputExample(tensor_dict['idx'].numpy(),
|
return InputExample(
|
||||||
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
tensor_dict["idx"].numpy(),
|
||||||
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
||||||
str(tensor_dict['label'].numpy()))
|
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
||||||
|
str(tensor_dict["label"].numpy()),
|
||||||
|
)
|
||||||
|
|
||||||
def get_train_examples(self, data_dir):
|
def get_train_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
|
||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
@@ -349,8 +352,7 @@ class StsbProcessor(DataProcessor):
|
|||||||
text_a = line[7]
|
text_a = line[7]
|
||||||
text_b = line[8]
|
text_b = line[8]
|
||||||
label = line[-1]
|
label = line[-1]
|
||||||
examples.append(
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
@@ -359,20 +361,20 @@ class QqpProcessor(DataProcessor):
|
|||||||
|
|
||||||
def get_example_from_tensor_dict(self, tensor_dict):
|
def get_example_from_tensor_dict(self, tensor_dict):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return InputExample(tensor_dict['idx'].numpy(),
|
return InputExample(
|
||||||
tensor_dict['question1'].numpy().decode('utf-8'),
|
tensor_dict["idx"].numpy(),
|
||||||
tensor_dict['question2'].numpy().decode('utf-8'),
|
tensor_dict["question1"].numpy().decode("utf-8"),
|
||||||
str(tensor_dict['label'].numpy()))
|
tensor_dict["question2"].numpy().decode("utf-8"),
|
||||||
|
str(tensor_dict["label"].numpy()),
|
||||||
|
)
|
||||||
|
|
||||||
def get_train_examples(self, data_dir):
|
def get_train_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
|
||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
@@ -391,8 +393,7 @@ class QqpProcessor(DataProcessor):
|
|||||||
label = line[5]
|
label = line[5]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
continue
|
continue
|
||||||
examples.append(
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
@@ -401,21 +402,20 @@ class QnliProcessor(DataProcessor):
|
|||||||
|
|
||||||
def get_example_from_tensor_dict(self, tensor_dict):
|
def get_example_from_tensor_dict(self, tensor_dict):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return InputExample(tensor_dict['idx'].numpy(),
|
return InputExample(
|
||||||
tensor_dict['question'].numpy().decode('utf-8'),
|
tensor_dict["idx"].numpy(),
|
||||||
tensor_dict['sentence'].numpy().decode('utf-8'),
|
tensor_dict["question"].numpy().decode("utf-8"),
|
||||||
str(tensor_dict['label'].numpy()))
|
tensor_dict["sentence"].numpy().decode("utf-8"),
|
||||||
|
str(tensor_dict["label"].numpy()),
|
||||||
|
)
|
||||||
|
|
||||||
def get_train_examples(self, data_dir):
|
def get_train_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
|
||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
|
||||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")),
|
|
||||||
"dev_matched")
|
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
@@ -431,8 +431,7 @@ class QnliProcessor(DataProcessor):
|
|||||||
text_a = line[1]
|
text_a = line[1]
|
||||||
text_b = line[2]
|
text_b = line[2]
|
||||||
label = line[-1]
|
label = line[-1]
|
||||||
examples.append(
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
@@ -441,20 +440,20 @@ class RteProcessor(DataProcessor):
|
|||||||
|
|
||||||
def get_example_from_tensor_dict(self, tensor_dict):
|
def get_example_from_tensor_dict(self, tensor_dict):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return InputExample(tensor_dict['idx'].numpy(),
|
return InputExample(
|
||||||
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
tensor_dict["idx"].numpy(),
|
||||||
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
||||||
str(tensor_dict['label'].numpy()))
|
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
||||||
|
str(tensor_dict["label"].numpy()),
|
||||||
|
)
|
||||||
|
|
||||||
def get_train_examples(self, data_dir):
|
def get_train_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
|
||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
@@ -470,8 +469,7 @@ class RteProcessor(DataProcessor):
|
|||||||
text_a = line[1]
|
text_a = line[1]
|
||||||
text_b = line[2]
|
text_b = line[2]
|
||||||
label = line[-1]
|
label = line[-1]
|
||||||
examples.append(
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
@@ -480,20 +478,20 @@ class WnliProcessor(DataProcessor):
|
|||||||
|
|
||||||
def get_example_from_tensor_dict(self, tensor_dict):
|
def get_example_from_tensor_dict(self, tensor_dict):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return InputExample(tensor_dict['idx'].numpy(),
|
return InputExample(
|
||||||
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
tensor_dict["idx"].numpy(),
|
||||||
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
||||||
str(tensor_dict['label'].numpy()))
|
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
||||||
|
str(tensor_dict["label"].numpy()),
|
||||||
|
)
|
||||||
|
|
||||||
def get_train_examples(self, data_dir):
|
def get_train_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
|
||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
@@ -509,10 +507,10 @@ class WnliProcessor(DataProcessor):
|
|||||||
text_a = line[1]
|
text_a = line[1]
|
||||||
text_b = line[2]
|
text_b = line[2]
|
||||||
label = line[-1]
|
label = line[-1]
|
||||||
examples.append(
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
glue_tasks_num_labels = {
|
glue_tasks_num_labels = {
|
||||||
"cola": 2,
|
"cola": 2,
|
||||||
"mnli": 3,
|
"mnli": 3,
|
||||||
|
|||||||
@@ -82,8 +82,8 @@ def _is_whitespace(c):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def squad_convert_example_to_features(example, max_seq_length,
|
|
||||||
doc_stride, max_query_length, is_training):
|
def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training):
|
||||||
features = []
|
features = []
|
||||||
if is_training and not example.is_impossible:
|
if is_training and not example.is_impossible:
|
||||||
# Get start and end position
|
# Get start and end position
|
||||||
@@ -121,8 +121,11 @@ def squad_convert_example_to_features(example, max_seq_length,
|
|||||||
spans = []
|
spans = []
|
||||||
|
|
||||||
truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
|
truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
|
||||||
sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + 1 \
|
sequence_added_tokens = (
|
||||||
if 'roberta' in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence
|
tokenizer.max_len - tokenizer.max_len_single_sentence + 1
|
||||||
|
if "roberta" in str(type(tokenizer))
|
||||||
|
else tokenizer.max_len - tokenizer.max_len_single_sentence
|
||||||
|
)
|
||||||
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
|
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
|
||||||
|
|
||||||
span_doc_tokens = all_doc_tokens
|
span_doc_tokens = all_doc_tokens
|
||||||
@@ -135,16 +138,18 @@ def squad_convert_example_to_features(example, max_seq_length,
|
|||||||
return_overflowing_tokens=True,
|
return_overflowing_tokens=True,
|
||||||
pad_to_max_length=True,
|
pad_to_max_length=True,
|
||||||
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
|
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
|
||||||
truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
|
truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
|
||||||
)
|
)
|
||||||
|
|
||||||
paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride,
|
paragraph_len = min(
|
||||||
max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
|
len(all_doc_tokens) - len(spans) * doc_stride,
|
||||||
|
max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
if tokenizer.pad_token_id in encoded_dict['input_ids']:
|
if tokenizer.pad_token_id in encoded_dict["input_ids"]:
|
||||||
non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
|
non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
|
||||||
else:
|
else:
|
||||||
non_padded_ids = encoded_dict['input_ids']
|
non_padded_ids = encoded_dict["input_ids"]
|
||||||
|
|
||||||
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
|
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
|
||||||
|
|
||||||
@@ -170,17 +175,20 @@ def squad_convert_example_to_features(example, max_seq_length,
|
|||||||
for doc_span_index in range(len(spans)):
|
for doc_span_index in range(len(spans)):
|
||||||
for j in range(spans[doc_span_index]["paragraph_len"]):
|
for j in range(spans[doc_span_index]["paragraph_len"]):
|
||||||
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
|
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
|
||||||
index = j if tokenizer.padding_side == "left" else spans[doc_span_index][
|
index = (
|
||||||
"truncated_query_with_special_tokens_length"] + j
|
j
|
||||||
|
if tokenizer.padding_side == "left"
|
||||||
|
else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
|
||||||
|
)
|
||||||
spans[doc_span_index]["token_is_max_context"][index] = is_max_context
|
spans[doc_span_index]["token_is_max_context"][index] = is_max_context
|
||||||
|
|
||||||
for span in spans:
|
for span in spans:
|
||||||
# Identify the position of the CLS token
|
# Identify the position of the CLS token
|
||||||
cls_index = span['input_ids'].index(tokenizer.cls_token_id)
|
cls_index = span["input_ids"].index(tokenizer.cls_token_id)
|
||||||
|
|
||||||
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
|
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
|
||||||
# Original TF implem also keep the classification token (set to 0) (not sure why...)
|
# Original TF implem also keep the classification token (set to 0) (not sure why...)
|
||||||
p_mask = np.array(span['token_type_ids'])
|
p_mask = np.array(span["token_type_ids"])
|
||||||
|
|
||||||
p_mask = np.minimum(p_mask, 1)
|
p_mask = np.minimum(p_mask, 1)
|
||||||
|
|
||||||
@@ -219,31 +227,34 @@ def squad_convert_example_to_features(example, max_seq_length,
|
|||||||
start_position = tok_start_position - doc_start + doc_offset
|
start_position = tok_start_position - doc_start + doc_offset
|
||||||
end_position = tok_end_position - doc_start + doc_offset
|
end_position = tok_end_position - doc_start + doc_offset
|
||||||
|
|
||||||
features.append(SquadFeatures(
|
features.append(
|
||||||
span['input_ids'],
|
SquadFeatures(
|
||||||
span['attention_mask'],
|
span["input_ids"],
|
||||||
span['token_type_ids'],
|
span["attention_mask"],
|
||||||
|
span["token_type_ids"],
|
||||||
cls_index,
|
cls_index,
|
||||||
p_mask.tolist(),
|
p_mask.tolist(),
|
||||||
example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing.
|
example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing.
|
||||||
unique_id=0,
|
unique_id=0,
|
||||||
paragraph_len=span['paragraph_len'],
|
paragraph_len=span["paragraph_len"],
|
||||||
token_is_max_context=span["token_is_max_context"],
|
token_is_max_context=span["token_is_max_context"],
|
||||||
tokens=span["tokens"],
|
tokens=span["tokens"],
|
||||||
token_to_orig_map=span["token_to_orig_map"],
|
token_to_orig_map=span["token_to_orig_map"],
|
||||||
|
|
||||||
start_position=start_position,
|
start_position=start_position,
|
||||||
end_position=end_position
|
end_position=end_position,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
def squad_convert_example_to_features_init(tokenizer_for_convert):
|
def squad_convert_example_to_features_init(tokenizer_for_convert):
|
||||||
global tokenizer
|
global tokenizer
|
||||||
tokenizer = tokenizer_for_convert
|
tokenizer = tokenizer_for_convert
|
||||||
|
|
||||||
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|
||||||
doc_stride, max_query_length, is_training,
|
def squad_convert_examples_to_features(
|
||||||
return_dataset=False, threads=1):
|
examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False, threads=1
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Converts a list of examples into a list of features that can be directly given as input to a model.
|
Converts a list of examples into a list of features that can be directly given as input to a model.
|
||||||
It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
|
It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
|
||||||
@@ -283,13 +294,24 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
features = []
|
features = []
|
||||||
threads = min(threads, cpu_count())
|
threads = min(threads, cpu_count())
|
||||||
with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
|
with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
|
||||||
annotate_ = partial(squad_convert_example_to_features, max_seq_length=max_seq_length,
|
annotate_ = partial(
|
||||||
doc_stride=doc_stride, max_query_length=max_query_length, is_training=is_training)
|
squad_convert_example_to_features,
|
||||||
features = list(tqdm(p.imap(annotate_, examples, chunksize=32), total=len(examples), desc='convert squad examples to features'))
|
max_seq_length=max_seq_length,
|
||||||
|
doc_stride=doc_stride,
|
||||||
|
max_query_length=max_query_length,
|
||||||
|
is_training=is_training,
|
||||||
|
)
|
||||||
|
features = list(
|
||||||
|
tqdm(
|
||||||
|
p.imap(annotate_, examples, chunksize=32),
|
||||||
|
total=len(examples),
|
||||||
|
desc="convert squad examples to features",
|
||||||
|
)
|
||||||
|
)
|
||||||
new_features = []
|
new_features = []
|
||||||
unique_id = 1000000000
|
unique_id = 1000000000
|
||||||
example_index = 0
|
example_index = 0
|
||||||
for example_features in tqdm(features, total=len(features), desc='add example index and unique id'):
|
for example_features in tqdm(features, total=len(features), desc="add example index and unique id"):
|
||||||
if not example_features:
|
if not example_features:
|
||||||
continue
|
continue
|
||||||
for example_feature in example_features:
|
for example_feature in example_features:
|
||||||
@@ -300,7 +322,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
example_index += 1
|
example_index += 1
|
||||||
features = new_features
|
features = new_features
|
||||||
del new_features
|
del new_features
|
||||||
if return_dataset == 'pt':
|
if return_dataset == "pt":
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
raise ImportError("Pytorch must be installed to return a pytorch dataset.")
|
raise ImportError("Pytorch must be installed to return a pytorch dataset.")
|
||||||
|
|
||||||
@@ -341,12 +363,13 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
"input_ids": ex.input_ids,
|
"input_ids": ex.input_ids,
|
||||||
"attention_mask": ex.attention_mask,
|
"attention_mask": ex.attention_mask,
|
||||||
"token_type_ids": ex.token_type_ids,
|
"token_type_ids": ex.token_type_ids,
|
||||||
}, {
|
},
|
||||||
|
{
|
||||||
"start_position": ex.start_position,
|
"start_position": ex.start_position,
|
||||||
"end_position": ex.end_position,
|
"end_position": ex.end_position,
|
||||||
"cls_index": ex.cls_index,
|
"cls_index": ex.cls_index,
|
||||||
"p_mask": ex.p_mask,
|
"p_mask": ex.p_mask,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return tf.data.Dataset.from_generator(
|
return tf.data.Dataset.from_generator(
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from ...file_utils import is_tf_available, is_torch_available
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class InputExample(object):
|
class InputExample(object):
|
||||||
"""
|
"""
|
||||||
A single training/test example for simple sequence classification.
|
A single training/test example for simple sequence classification.
|
||||||
@@ -37,6 +38,7 @@ class InputExample(object):
|
|||||||
label: (Optional) string. The label of the example. This should be
|
label: (Optional) string. The label of the example. This should be
|
||||||
specified for train and dev examples, but not for test examples.
|
specified for train and dev examples, but not for test examples.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, guid, text_a, text_b=None, label=None):
|
def __init__(self, guid, text_a, text_b=None, label=None):
|
||||||
self.guid = guid
|
self.guid = guid
|
||||||
self.text_a = text_a
|
self.text_a = text_a
|
||||||
@@ -99,14 +101,15 @@ class DataProcessor(object):
|
|||||||
lines = []
|
lines = []
|
||||||
for line in reader:
|
for line in reader:
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
line = list(unicode(cell, 'utf-8') for cell in line)
|
line = list(unicode(cell, "utf-8") for cell in line)
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
class SingleSentenceClassificationProcessor(DataProcessor):
|
class SingleSentenceClassificationProcessor(DataProcessor):
|
||||||
""" Generic processor for a single sentence classification data set."""
|
""" Generic processor for a single sentence classification data set."""
|
||||||
def __init__(self, labels=None, examples=None, mode='classification', verbose=False):
|
|
||||||
|
def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
|
||||||
self.labels = [] if labels is None else labels
|
self.labels = [] if labels is None else labels
|
||||||
self.examples = [] if examples is None else examples
|
self.examples = [] if examples is None else examples
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
@@ -117,22 +120,24 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
|||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
if isinstance(idx, slice):
|
if isinstance(idx, slice):
|
||||||
return SingleSentenceClassificationProcessor(labels=self.labels,
|
return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
|
||||||
examples=self.examples[idx])
|
|
||||||
return self.examples[idx]
|
return self.examples[idx]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_from_csv(cls, file_name, split_name='', column_label=0, column_text=1,
|
def create_from_csv(
|
||||||
column_id=None, skip_first_row=False, **kwargs):
|
cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
|
||||||
|
):
|
||||||
processor = cls(**kwargs)
|
processor = cls(**kwargs)
|
||||||
processor.add_examples_from_csv(file_name,
|
processor.add_examples_from_csv(
|
||||||
|
file_name,
|
||||||
split_name=split_name,
|
split_name=split_name,
|
||||||
column_label=column_label,
|
column_label=column_label,
|
||||||
column_text=column_text,
|
column_text=column_text,
|
||||||
column_id=column_id,
|
column_id=column_id,
|
||||||
skip_first_row=skip_first_row,
|
skip_first_row=skip_first_row,
|
||||||
overwrite_labels=True,
|
overwrite_labels=True,
|
||||||
overwrite_examples=True)
|
overwrite_examples=True,
|
||||||
|
)
|
||||||
return processor
|
return processor
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -141,8 +146,17 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
|||||||
processor.add_examples(texts_or_text_and_labels, labels=labels)
|
processor.add_examples(texts_or_text_and_labels, labels=labels)
|
||||||
return processor
|
return processor
|
||||||
|
|
||||||
def add_examples_from_csv(self, file_name, split_name='', column_label=0, column_text=1, column_id=None,
|
def add_examples_from_csv(
|
||||||
skip_first_row=False, overwrite_labels=False, overwrite_examples=False):
|
self,
|
||||||
|
file_name,
|
||||||
|
split_name="",
|
||||||
|
column_label=0,
|
||||||
|
column_text=1,
|
||||||
|
column_id=None,
|
||||||
|
skip_first_row=False,
|
||||||
|
overwrite_labels=False,
|
||||||
|
overwrite_examples=False,
|
||||||
|
):
|
||||||
lines = self._read_tsv(file_name)
|
lines = self._read_tsv(file_name)
|
||||||
if skip_first_row:
|
if skip_first_row:
|
||||||
lines = lines[1:]
|
lines = lines[1:]
|
||||||
@@ -158,10 +172,13 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
|||||||
guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
|
guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
|
||||||
ids.append(guid)
|
ids.append(guid)
|
||||||
|
|
||||||
return self.add_examples(texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples)
|
return self.add_examples(
|
||||||
|
texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
|
||||||
|
)
|
||||||
|
|
||||||
def add_examples(self, texts_or_text_and_labels, labels=None, ids=None,
|
def add_examples(
|
||||||
overwrite_labels=False, overwrite_examples=False):
|
self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
|
||||||
|
):
|
||||||
assert labels is None or len(texts_or_text_and_labels) == len(labels)
|
assert labels is None or len(texts_or_text_and_labels) == len(labels)
|
||||||
assert ids is None or len(texts_or_text_and_labels) == len(ids)
|
assert ids is None or len(texts_or_text_and_labels) == len(ids)
|
||||||
if ids is None:
|
if ids is None:
|
||||||
@@ -192,13 +209,15 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
|||||||
|
|
||||||
return self.examples
|
return self.examples
|
||||||
|
|
||||||
def get_features(self,
|
def get_features(
|
||||||
|
self,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
max_length=None,
|
max_length=None,
|
||||||
pad_on_left=False,
|
pad_on_left=False,
|
||||||
pad_token=0,
|
pad_token=0,
|
||||||
mask_padding_with_zero=True,
|
mask_padding_with_zero=True,
|
||||||
return_tensors=None):
|
return_tensors=None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Convert examples in a list of ``InputFeatures``
|
Convert examples in a list of ``InputFeatures``
|
||||||
|
|
||||||
@@ -231,9 +250,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
|||||||
logger.info("Tokenizing example %d", ex_index)
|
logger.info("Tokenizing example %d", ex_index)
|
||||||
|
|
||||||
input_ids = tokenizer.encode(
|
input_ids = tokenizer.encode(
|
||||||
example.text_a,
|
example.text_a, add_special_tokens=True, max_length=min(max_length, tokenizer.max_len),
|
||||||
add_special_tokens=True,
|
|
||||||
max_length=min(max_length, tokenizer.max_len),
|
|
||||||
)
|
)
|
||||||
all_input_ids.append(input_ids)
|
all_input_ids.append(input_ids)
|
||||||
|
|
||||||
@@ -256,8 +273,12 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
|||||||
input_ids = input_ids + ([pad_token] * padding_length)
|
input_ids = input_ids + ([pad_token] * padding_length)
|
||||||
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
|
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
|
||||||
|
|
||||||
assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(len(input_ids), batch_length)
|
assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(
|
||||||
assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(len(attention_mask), batch_length)
|
len(input_ids), batch_length
|
||||||
|
)
|
||||||
|
assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(
|
||||||
|
len(attention_mask), batch_length
|
||||||
|
)
|
||||||
|
|
||||||
if self.mode == "classification":
|
if self.mode == "classification":
|
||||||
label = label_map[example.label]
|
label = label_map[example.label]
|
||||||
@@ -273,36 +294,31 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
|||||||
logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
|
logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
|
||||||
logger.info("label: %s (id = %d)" % (example.label, label))
|
logger.info("label: %s (id = %d)" % (example.label, label))
|
||||||
|
|
||||||
features.append(
|
features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
|
||||||
InputFeatures(input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
label=label))
|
|
||||||
|
|
||||||
if return_tensors is None:
|
if return_tensors is None:
|
||||||
return features
|
return features
|
||||||
elif return_tensors == 'tf':
|
elif return_tensors == "tf":
|
||||||
if not is_tf_available():
|
if not is_tf_available():
|
||||||
raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
|
raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
for ex in features:
|
for ex in features:
|
||||||
yield ({'input_ids': ex.input_ids,
|
yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)
|
||||||
'attention_mask': ex.attention_mask},
|
|
||||||
ex.label)
|
|
||||||
|
|
||||||
dataset = tf.data.Dataset.from_generator(gen,
|
dataset = tf.data.Dataset.from_generator(
|
||||||
({'input_ids': tf.int32,
|
gen,
|
||||||
'attention_mask': tf.int32},
|
({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
|
||||||
tf.int64),
|
({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
|
||||||
({'input_ids': tf.TensorShape([None]),
|
)
|
||||||
'attention_mask': tf.TensorShape([None])},
|
|
||||||
tf.TensorShape([])))
|
|
||||||
return dataset
|
return dataset
|
||||||
elif return_tensors == 'pt':
|
elif return_tensors == "pt":
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
|
raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import TensorDataset
|
from torch.utils.data import TensorDataset
|
||||||
|
|
||||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||||
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
||||||
if self.mode == "classification":
|
if self.mode == "classification":
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from .utils import DataProcessor, InputExample
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class XnliProcessor(DataProcessor):
|
class XnliProcessor(DataProcessor):
|
||||||
"""Processor for the XNLI dataset.
|
"""Processor for the XNLI dataset.
|
||||||
Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""
|
Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""
|
||||||
@@ -40,13 +41,12 @@ class XnliProcessor(DataProcessor):
|
|||||||
for (i, line) in enumerate(lines):
|
for (i, line) in enumerate(lines):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
continue
|
continue
|
||||||
guid = "%s-%s" % ('train', i)
|
guid = "%s-%s" % ("train", i)
|
||||||
text_a = line[0]
|
text_a = line[0]
|
||||||
text_b = line[1]
|
text_b = line[1]
|
||||||
label = "contradiction" if line[2] == "contradictory" else line[2]
|
label = "contradiction" if line[2] == "contradictory" else line[2]
|
||||||
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
|
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
|
||||||
examples.append(
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
def get_test_examples(self, data_dir):
|
def get_test_examples(self, data_dir):
|
||||||
@@ -59,19 +59,19 @@ class XnliProcessor(DataProcessor):
|
|||||||
language = line[0]
|
language = line[0]
|
||||||
if language != self.language:
|
if language != self.language:
|
||||||
continue
|
continue
|
||||||
guid = "%s-%s" % ('test', i)
|
guid = "%s-%s" % ("test", i)
|
||||||
text_a = line[6]
|
text_a = line[6]
|
||||||
text_b = line[7]
|
text_b = line[7]
|
||||||
label = line[1]
|
label = line[1]
|
||||||
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
|
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
|
||||||
examples.append(
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["contradiction", "entailment", "neutral"]
|
return ["contradiction", "entailment", "neutral"]
|
||||||
|
|
||||||
|
|
||||||
xnli_processors = {
|
xnli_processors = {
|
||||||
"xnli": XnliProcessor,
|
"xnli": XnliProcessor,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ Utilities for working with the local dataset cache.
|
|||||||
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
||||||
Copyright by the AllenNLP authors.
|
Copyright by the AllenNLP authors.
|
||||||
"""
|
"""
|
||||||
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
@@ -29,9 +29,10 @@ from filelock import FileLock
|
|||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.environ.setdefault('USE_TORCH', 'YES')
|
os.environ.setdefault("USE_TORCH", "YES")
|
||||||
if os.environ['USE_TORCH'].upper() in ('1', 'ON', 'YES'):
|
if os.environ["USE_TORCH"].upper() in ("1", "ON", "YES"):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
_torch_available = True # pylint: disable=invalid-name
|
_torch_available = True # pylint: disable=invalid-name
|
||||||
logger.info("PyTorch version {} available.".format(torch.__version__))
|
logger.info("PyTorch version {} available.".format(torch.__version__))
|
||||||
else:
|
else:
|
||||||
@@ -41,10 +42,11 @@ except ImportError:
|
|||||||
_torch_available = False # pylint: disable=invalid-name
|
_torch_available = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.environ.setdefault('USE_TF', 'YES')
|
os.environ.setdefault("USE_TF", "YES")
|
||||||
if os.environ['USE_TF'].upper() in ('1', 'ON', 'YES'):
|
if os.environ["USE_TF"].upper() in ("1", "ON", "YES"):
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
|
|
||||||
|
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
|
||||||
_tf_available = True # pylint: disable=invalid-name
|
_tf_available = True # pylint: disable=invalid-name
|
||||||
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
||||||
else:
|
else:
|
||||||
@@ -55,12 +57,13 @@ except (ImportError, AssertionError):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.hub import _get_torch_home
|
from torch.hub import _get_torch_home
|
||||||
|
|
||||||
torch_cache_home = _get_torch_home()
|
torch_cache_home = _get_torch_home()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
torch_cache_home = os.path.expanduser(
|
torch_cache_home = os.path.expanduser(
|
||||||
os.getenv('TORCH_HOME', os.path.join(
|
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
|
||||||
os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
|
)
|
||||||
default_cache_path = os.path.join(torch_cache_home, 'transformers')
|
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
@@ -69,19 +72,21 @@ except ImportError:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
PYTORCH_PRETRAINED_BERT_CACHE = Path(
|
PYTORCH_PRETRAINED_BERT_CACHE = Path(
|
||||||
os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
|
os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
|
||||||
|
)
|
||||||
except (AttributeError, ImportError):
|
except (AttributeError, ImportError):
|
||||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
|
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
|
||||||
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
"PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
||||||
default_cache_path))
|
)
|
||||||
|
|
||||||
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
||||||
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
||||||
|
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
TF2_WEIGHTS_NAME = 'tf_model.h5'
|
TF2_WEIGHTS_NAME = "tf_model.h5"
|
||||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
TF_WEIGHTS_NAME = "model.ckpt"
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
MODEL_CARD_NAME = "modelcard.json"
|
MODEL_CARD_NAME = "modelcard.json"
|
||||||
|
|
||||||
@@ -95,38 +100,48 @@ CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"
|
|||||||
def is_torch_available():
|
def is_torch_available():
|
||||||
return _torch_available
|
return _torch_available
|
||||||
|
|
||||||
|
|
||||||
def is_tf_available():
|
def is_tf_available():
|
||||||
|
|
||||||
return _tf_available
|
return _tf_available
|
||||||
|
|
||||||
|
|
||||||
if not six.PY2:
|
if not six.PY2:
|
||||||
|
|
||||||
def add_start_docstrings(*docstr):
|
def add_start_docstrings(*docstr):
|
||||||
def docstring_decorator(fn):
|
def docstring_decorator(fn):
|
||||||
fn.__doc__ = ''.join(docstr) + fn.__doc__
|
fn.__doc__ = "".join(docstr) + fn.__doc__
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
return docstring_decorator
|
return docstring_decorator
|
||||||
|
|
||||||
def add_end_docstrings(*docstr):
|
def add_end_docstrings(*docstr):
|
||||||
def docstring_decorator(fn):
|
def docstring_decorator(fn):
|
||||||
fn.__doc__ = fn.__doc__ + ''.join(docstr)
|
fn.__doc__ = fn.__doc__ + "".join(docstr)
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
return docstring_decorator
|
return docstring_decorator
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Not possible to update class docstrings on python2
|
# Not possible to update class docstrings on python2
|
||||||
def add_start_docstrings(*docstr):
|
def add_start_docstrings(*docstr):
|
||||||
def docstring_decorator(fn):
|
def docstring_decorator(fn):
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
return docstring_decorator
|
return docstring_decorator
|
||||||
|
|
||||||
def add_end_docstrings(*docstr):
|
def add_end_docstrings(*docstr):
|
||||||
def docstring_decorator(fn):
|
def docstring_decorator(fn):
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
return docstring_decorator
|
return docstring_decorator
|
||||||
|
|
||||||
|
|
||||||
def is_remote_url(url_or_filename):
|
def is_remote_url(url_or_filename):
|
||||||
parsed = urlparse(url_or_filename)
|
parsed = urlparse(url_or_filename)
|
||||||
return parsed.scheme in ('http', 'https', 's3')
|
return parsed.scheme in ("http", "https", "s3")
|
||||||
|
|
||||||
|
|
||||||
def hf_bucket_url(identifier, postfix=None, cdn=False):
|
def hf_bucket_url(identifier, postfix=None, cdn=False):
|
||||||
endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
|
endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
|
||||||
@@ -145,17 +160,17 @@ def url_to_filename(url, etag=None):
|
|||||||
so that TF 2.0 can identify it as a HDF5 file
|
so that TF 2.0 can identify it as a HDF5 file
|
||||||
(see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
|
(see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
|
||||||
"""
|
"""
|
||||||
url_bytes = url.encode('utf-8')
|
url_bytes = url.encode("utf-8")
|
||||||
url_hash = sha256(url_bytes)
|
url_hash = sha256(url_bytes)
|
||||||
filename = url_hash.hexdigest()
|
filename = url_hash.hexdigest()
|
||||||
|
|
||||||
if etag:
|
if etag:
|
||||||
etag_bytes = etag.encode('utf-8')
|
etag_bytes = etag.encode("utf-8")
|
||||||
etag_hash = sha256(etag_bytes)
|
etag_hash = sha256(etag_bytes)
|
||||||
filename += '.' + etag_hash.hexdigest()
|
filename += "." + etag_hash.hexdigest()
|
||||||
|
|
||||||
if url.endswith('.h5'):
|
if url.endswith(".h5"):
|
||||||
filename += '.h5'
|
filename += ".h5"
|
||||||
|
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
@@ -174,19 +189,21 @@ def filename_to_url(filename, cache_dir=None):
|
|||||||
if not os.path.exists(cache_path):
|
if not os.path.exists(cache_path):
|
||||||
raise EnvironmentError("file {} not found".format(cache_path))
|
raise EnvironmentError("file {} not found".format(cache_path))
|
||||||
|
|
||||||
meta_path = cache_path + '.json'
|
meta_path = cache_path + ".json"
|
||||||
if not os.path.exists(meta_path):
|
if not os.path.exists(meta_path):
|
||||||
raise EnvironmentError("file {} not found".format(meta_path))
|
raise EnvironmentError("file {} not found".format(meta_path))
|
||||||
|
|
||||||
with open(meta_path, encoding="utf-8") as meta_file:
|
with open(meta_path, encoding="utf-8") as meta_file:
|
||||||
metadata = json.load(meta_file)
|
metadata = json.load(meta_file)
|
||||||
url = metadata['url']
|
url = metadata["url"]
|
||||||
etag = metadata['etag']
|
etag = metadata["etag"]
|
||||||
|
|
||||||
return url, etag
|
return url, etag
|
||||||
|
|
||||||
|
|
||||||
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None):
|
def cached_path(
|
||||||
|
url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Given something that might be a URL (or might be a local path),
|
Given something that might be a URL (or might be a local path),
|
||||||
determine which. If it's a URL, download the file and cache it, and
|
determine which. If it's a URL, download the file and cache it, and
|
||||||
@@ -207,13 +224,18 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
|
|||||||
|
|
||||||
if is_remote_url(url_or_filename):
|
if is_remote_url(url_or_filename):
|
||||||
# URL, so get it from the cache (downloading if necessary)
|
# URL, so get it from the cache (downloading if necessary)
|
||||||
return get_from_cache(url_or_filename, cache_dir=cache_dir,
|
return get_from_cache(
|
||||||
force_download=force_download, proxies=proxies,
|
url_or_filename,
|
||||||
resume_download=resume_download, user_agent=user_agent)
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
elif os.path.exists(url_or_filename):
|
elif os.path.exists(url_or_filename):
|
||||||
# File, and it exists.
|
# File, and it exists.
|
||||||
return url_or_filename
|
return url_or_filename
|
||||||
elif urlparse(url_or_filename).scheme == '':
|
elif urlparse(url_or_filename).scheme == "":
|
||||||
# File, but it doesn't exist.
|
# File, but it doesn't exist.
|
||||||
raise EnvironmentError("file {} not found".format(url_or_filename))
|
raise EnvironmentError("file {} not found".format(url_or_filename))
|
||||||
else:
|
else:
|
||||||
@@ -273,23 +295,25 @@ def s3_get(url, temp_file, proxies=None):
|
|||||||
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
||||||
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
|
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
|
||||||
if isinstance(user_agent, dict):
|
if isinstance(user_agent, dict):
|
||||||
ua += "; " + "; ".join(
|
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
|
||||||
"{}/{}".format(k, v) for k, v in user_agent.items()
|
|
||||||
)
|
|
||||||
elif isinstance(user_agent, six.string_types):
|
elif isinstance(user_agent, six.string_types):
|
||||||
ua += "; " + user_agent
|
ua += "; " + user_agent
|
||||||
headers = {
|
headers = {"user-agent": ua}
|
||||||
"user-agent": ua
|
|
||||||
}
|
|
||||||
if resume_size > 0:
|
if resume_size > 0:
|
||||||
headers['Range'] = 'bytes=%d-' % (resume_size,)
|
headers["Range"] = "bytes=%d-" % (resume_size,)
|
||||||
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||||
if response.status_code == 416: # Range not satisfiable
|
if response.status_code == 416: # Range not satisfiable
|
||||||
return
|
return
|
||||||
content_length = response.headers.get('Content-Length')
|
content_length = response.headers.get("Content-Length")
|
||||||
total = resume_size + int(content_length) if content_length is not None else None
|
total = resume_size + int(content_length) if content_length is not None else None
|
||||||
progress = tqdm(unit="B", unit_scale=True, total=total, initial=resume_size,
|
progress = tqdm(
|
||||||
desc="Downloading", disable=bool(logger.level<=logging.INFO))
|
unit="B",
|
||||||
|
unit_scale=True,
|
||||||
|
total=total,
|
||||||
|
initial=resume_size,
|
||||||
|
desc="Downloading",
|
||||||
|
disable=bool(logger.level <= logging.INFO),
|
||||||
|
)
|
||||||
for chunk in response.iter_content(chunk_size=1024):
|
for chunk in response.iter_content(chunk_size=1024):
|
||||||
if chunk: # filter out keep-alive new chunks
|
if chunk: # filter out keep-alive new chunks
|
||||||
progress.update(len(chunk))
|
progress.update(len(chunk))
|
||||||
@@ -297,7 +321,9 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
|||||||
progress.close()
|
progress.close()
|
||||||
|
|
||||||
|
|
||||||
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None):
|
def get_from_cache(
|
||||||
|
url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Given a URL, look for the corresponding dataset in the local cache.
|
Given a URL, look for the corresponding dataset in the local cache.
|
||||||
If it's not there, download it. Then return the path to the cached file.
|
If it's not there, download it. Then return the path to the cached file.
|
||||||
@@ -326,7 +352,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
|
|||||||
etag = None
|
etag = None
|
||||||
|
|
||||||
if sys.version_info[0] == 2 and etag is not None:
|
if sys.version_info[0] == 2 and etag is not None:
|
||||||
etag = etag.decode('utf-8')
|
etag = etag.decode("utf-8")
|
||||||
filename = url_to_filename(url, etag)
|
filename = url_to_filename(url, etag)
|
||||||
|
|
||||||
# get cache path to put the file
|
# get cache path to put the file
|
||||||
@@ -337,22 +363,24 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
|
|||||||
if not os.path.exists(cache_path) and etag is None:
|
if not os.path.exists(cache_path) and etag is None:
|
||||||
matching_files = [
|
matching_files = [
|
||||||
file
|
file
|
||||||
for file in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
|
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
|
||||||
if not file.endswith('.json') and not file.endswith('.lock')
|
if not file.endswith(".json") and not file.endswith(".lock")
|
||||||
]
|
]
|
||||||
if matching_files:
|
if matching_files:
|
||||||
cache_path = os.path.join(cache_dir, matching_files[-1])
|
cache_path = os.path.join(cache_dir, matching_files[-1])
|
||||||
|
|
||||||
# Prevent parallel downloads of the same file with a lock.
|
# Prevent parallel downloads of the same file with a lock.
|
||||||
lock_path = cache_path + '.lock'
|
lock_path = cache_path + ".lock"
|
||||||
with FileLock(lock_path):
|
with FileLock(lock_path):
|
||||||
|
|
||||||
if resume_download:
|
if resume_download:
|
||||||
incomplete_path = cache_path + '.incomplete'
|
incomplete_path = cache_path + ".incomplete"
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _resumable_file_manager():
|
def _resumable_file_manager():
|
||||||
with open(incomplete_path,'a+b') as f:
|
with open(incomplete_path, "a+b") as f:
|
||||||
yield f
|
yield f
|
||||||
|
|
||||||
temp_file_manager = _resumable_file_manager
|
temp_file_manager = _resumable_file_manager
|
||||||
if os.path.exists(incomplete_path):
|
if os.path.exists(incomplete_path):
|
||||||
resume_size = os.stat(incomplete_path).st_size
|
resume_size = os.stat(incomplete_path).st_size
|
||||||
@@ -366,7 +394,9 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
|
|||||||
# Download to temporary file, then copy to cache dir once finished.
|
# Download to temporary file, then copy to cache dir once finished.
|
||||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||||
with temp_file_manager() as temp_file:
|
with temp_file_manager() as temp_file:
|
||||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
logger.info(
|
||||||
|
"%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
|
||||||
|
)
|
||||||
|
|
||||||
# GET file object
|
# GET file object
|
||||||
if url.startswith("s3://"):
|
if url.startswith("s3://"):
|
||||||
@@ -383,12 +413,12 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
|
|||||||
os.rename(temp_file.name, cache_path)
|
os.rename(temp_file.name, cache_path)
|
||||||
|
|
||||||
logger.info("creating metadata file for %s", cache_path)
|
logger.info("creating metadata file for %s", cache_path)
|
||||||
meta = {'url': url, 'etag': etag}
|
meta = {"url": url, "etag": etag}
|
||||||
meta_path = cache_path + '.json'
|
meta_path = cache_path + ".json"
|
||||||
with open(meta_path, 'w') as meta_file:
|
with open(meta_path, "w") as meta_file:
|
||||||
output_string = json.dumps(meta)
|
output_string = json.dumps(meta)
|
||||||
if sys.version_info[0] == 2 and isinstance(output_string, str):
|
if sys.version_info[0] == 2 and isinstance(output_string, str):
|
||||||
output_string = unicode(output_string, 'utf-8') # The beauty of python 2
|
output_string = unicode(output_string, "utf-8") # The beauty of python 2
|
||||||
meta_file.write(output_string)
|
meta_file.write(output_string)
|
||||||
|
|
||||||
return cache_path
|
return cache_path
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
ENDPOINT = "https://huggingface.co"
|
ENDPOINT = "https://huggingface.co"
|
||||||
|
|
||||||
|
|
||||||
class S3Obj:
|
class S3Obj:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -78,8 +79,7 @@ class HfApi:
|
|||||||
return d["token"]
|
return d["token"]
|
||||||
|
|
||||||
def whoami(
|
def whoami(
|
||||||
self,
|
self, token, # type: str
|
||||||
token, # type: str
|
|
||||||
):
|
):
|
||||||
# type: (...) -> str
|
# type: (...) -> str
|
||||||
"""
|
"""
|
||||||
@@ -106,11 +106,7 @@ class HfApi:
|
|||||||
Call HF API to get a presigned url to upload `filename` to S3.
|
Call HF API to get a presigned url to upload `filename` to S3.
|
||||||
"""
|
"""
|
||||||
path = "{}/api/presign".format(self.endpoint)
|
path = "{}/api/presign".format(self.endpoint)
|
||||||
r = requests.post(
|
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename},)
|
||||||
path,
|
|
||||||
headers={"authorization": "Bearer {}".format(token)},
|
|
||||||
json={"filename": filename},
|
|
||||||
)
|
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
d = r.json()
|
d = r.json()
|
||||||
return PresignedUrl(**d)
|
return PresignedUrl(**d)
|
||||||
@@ -133,9 +129,7 @@ class HfApi:
|
|||||||
pf = TqdmProgressFileReader(f)
|
pf = TqdmProgressFileReader(f)
|
||||||
data = f if pf.total_size > 0 else ""
|
data = f if pf.total_size > 0 else ""
|
||||||
|
|
||||||
r = requests.put(urls.write, data=data, headers={
|
r = requests.put(urls.write, data=data, headers={"content-type": urls.type,})
|
||||||
"content-type": urls.type,
|
|
||||||
})
|
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
pf.close()
|
pf.close()
|
||||||
return urls.access
|
return urls.access
|
||||||
@@ -152,7 +146,6 @@ class HfApi:
|
|||||||
return [S3Obj(**x) for x in d]
|
return [S3Obj(**x) for x in d]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TqdmProgressFileReader:
|
class TqdmProgressFileReader:
|
||||||
"""
|
"""
|
||||||
Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`)
|
Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`)
|
||||||
@@ -161,9 +154,9 @@ class TqdmProgressFileReader:
|
|||||||
see github.com/huggingface/transformers/pull/2078#discussion_r354739608
|
see github.com/huggingface/transformers/pull/2078#discussion_r354739608
|
||||||
for implementation details.
|
for implementation details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, f # type: io.BufferedReader
|
||||||
f # type: io.BufferedReader
|
|
||||||
):
|
):
|
||||||
self.f = f
|
self.f = f
|
||||||
self.total_size = os.fstat(f.fileno()).st_size # type: int
|
self.total_size = os.fstat(f.fileno()).st_size # type: int
|
||||||
@@ -182,7 +175,6 @@ class TqdmProgressFileReader:
|
|||||||
self.pbar.close()
|
self.pbar.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HfFolder:
|
class HfFolder:
|
||||||
path_token = expanduser("~/.huggingface/token")
|
path_token = expanduser("~/.huggingface/token")
|
||||||
|
|
||||||
@@ -201,7 +193,7 @@ class HfFolder:
|
|||||||
if e.errno != os.errno.EEXIST:
|
if e.errno != os.errno.EEXIST:
|
||||||
raise e
|
raise e
|
||||||
pass
|
pass
|
||||||
with open(cls.path_token, 'w+') as f:
|
with open(cls.path_token, "w+") as f:
|
||||||
f.write(token)
|
f.write(token)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -210,7 +202,7 @@ class HfFolder:
|
|||||||
Get token or None if not existent.
|
Get token or None if not existent.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with open(cls.path_token, 'r') as f:
|
with open(cls.path_token, "r") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
except:
|
except:
|
||||||
# this is too wide. When Py2 is dead use:
|
# this is too wide. When Py2 is dead use:
|
||||||
|
|||||||
@@ -14,8 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Configuration base class and utilities."""
|
""" Configuration base class and utilities."""
|
||||||
|
|
||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
unicode_literals)
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
@@ -25,8 +24,15 @@ from io import open
|
|||||||
|
|
||||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
from .file_utils import CONFIG_NAME, MODEL_CARD_NAME, WEIGHTS_NAME, TF2_WEIGHTS_NAME, \
|
from .file_utils import (
|
||||||
cached_path, is_remote_url, hf_bucket_url
|
CONFIG_NAME,
|
||||||
|
MODEL_CARD_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
TF2_WEIGHTS_NAME,
|
||||||
|
cached_path,
|
||||||
|
is_remote_url,
|
||||||
|
hf_bucket_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -48,17 +54,18 @@ class ModelCard(object):
|
|||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
# Recomended attributes from https://arxiv.org/abs/1810.03993 (see papers)
|
# Recomended attributes from https://arxiv.org/abs/1810.03993 (see papers)
|
||||||
self.model_details = kwargs.pop('model_details', {})
|
self.model_details = kwargs.pop("model_details", {})
|
||||||
self.intended_use = kwargs.pop('intended_use', {})
|
self.intended_use = kwargs.pop("intended_use", {})
|
||||||
self.factors = kwargs.pop('factors', {})
|
self.factors = kwargs.pop("factors", {})
|
||||||
self.metrics = kwargs.pop('metrics', {})
|
self.metrics = kwargs.pop("metrics", {})
|
||||||
self.evaluation_data = kwargs.pop('evaluation_data', {})
|
self.evaluation_data = kwargs.pop("evaluation_data", {})
|
||||||
self.training_data = kwargs.pop('training_data', {})
|
self.training_data = kwargs.pop("training_data", {})
|
||||||
self.quantitative_analyses = kwargs.pop('quantitative_analyses', {})
|
self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
|
||||||
self.ethical_considerations = kwargs.pop('ethical_considerations', {})
|
self.ethical_considerations = kwargs.pop("ethical_considerations", {})
|
||||||
self.caveats_and_recommendations = kwargs.pop('caveats_and_recommendations', {})
|
self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})
|
||||||
|
|
||||||
# Open additional attributes
|
# Open additional attributes
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
@@ -122,10 +129,10 @@ class ModelCard(object):
|
|||||||
modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
|
modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
cache_dir = kwargs.pop('cache_dir', None)
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
proxies = kwargs.pop('proxies', None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
find_from_standard_name = kwargs.pop('find_from_standard_name', True)
|
find_from_standard_name = kwargs.pop("find_from_standard_name", True)
|
||||||
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||||
|
|
||||||
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
||||||
# For simplicity we use the same pretrained url than the configuration files
|
# For simplicity we use the same pretrained url than the configuration files
|
||||||
@@ -145,36 +152,43 @@ class ModelCard(object):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, force_download=True,
|
resolved_model_card_file = cached_path(
|
||||||
proxies=proxies, resume_download=False)
|
model_card_file, cache_dir=cache_dir, force_download=True, proxies=proxies, resume_download=False
|
||||||
|
)
|
||||||
if resolved_model_card_file == model_card_file:
|
if resolved_model_card_file == model_card_file:
|
||||||
logger.info("loading model card file {}".format(model_card_file))
|
logger.info("loading model card file {}".format(model_card_file))
|
||||||
else:
|
else:
|
||||||
logger.info("loading model card file {} from cache at {}".format(
|
logger.info(
|
||||||
model_card_file, resolved_model_card_file))
|
"loading model card file {} from cache at {}".format(model_card_file, resolved_model_card_file)
|
||||||
|
)
|
||||||
# Load model card
|
# Load model card
|
||||||
modelcard = cls.from_json_file(resolved_model_card_file)
|
modelcard = cls.from_json_file(resolved_model_card_file)
|
||||||
|
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
||||||
logger.warning("Couldn't reach server at '{}' to download model card file.".format(
|
logger.warning("Couldn't reach server at '{}' to download model card file.".format(model_card_file))
|
||||||
model_card_file))
|
|
||||||
else:
|
else:
|
||||||
logger.warning("Model name '{}' was not found in model name list ({}). " \
|
logger.warning(
|
||||||
"We assumed '{}' was a path or url to a model card file named {} or " \
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
|
"We assumed '{}' was a path or url to a model card file named {} or "
|
||||||
"a directory containing such a file but couldn't find any such file at this path or url.".format(
|
"a directory containing such a file but couldn't find any such file at this path or url.".format(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
', '.join(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
|
", ".join(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
|
||||||
model_card_file, MODEL_CARD_NAME))
|
model_card_file,
|
||||||
|
MODEL_CARD_NAME,
|
||||||
|
)
|
||||||
|
)
|
||||||
logger.warning("Creating an empty model card.")
|
logger.warning("Creating an empty model card.")
|
||||||
|
|
||||||
# We fall back on creating an empty model card
|
# We fall back on creating an empty model card
|
||||||
modelcard = cls()
|
modelcard = cls()
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning("Couldn't reach server at '{}' to download model card file or "
|
logger.warning(
|
||||||
|
"Couldn't reach server at '{}' to download model card file or "
|
||||||
"model card file is not a valid JSON file. "
|
"model card file is not a valid JSON file. "
|
||||||
"Please check network or file content here: {}.".format(model_card_file, resolved_model_card_file))
|
"Please check network or file content here: {}.".format(model_card_file, resolved_model_card_file)
|
||||||
|
)
|
||||||
logger.warning("Creating an empty model card.")
|
logger.warning("Creating an empty model card.")
|
||||||
|
|
||||||
# We fall back on creating an empty model card
|
# We fall back on creating an empty model card
|
||||||
@@ -203,7 +217,7 @@ class ModelCard(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_json_file(cls, json_file):
|
def from_json_file(cls, json_file):
|
||||||
"""Constructs a `ModelCard` from a json file of parameters."""
|
"""Constructs a `ModelCard` from a json file of parameters."""
|
||||||
with open(json_file, "r", encoding='utf-8') as reader:
|
with open(json_file, "r", encoding="utf-8") as reader:
|
||||||
text = reader.read()
|
text = reader.read()
|
||||||
dict_obj = json.loads(text)
|
dict_obj = json.loads(text)
|
||||||
return cls(**dict_obj)
|
return cls(**dict_obj)
|
||||||
@@ -225,5 +239,5 @@ class ModelCard(object):
|
|||||||
|
|
||||||
def to_json_file(self, json_file_path):
|
def to_json_file(self, json_file_path):
|
||||||
""" Save this instance to a json file."""
|
""" Save this instance to a json file."""
|
||||||
with open(json_file_path, "w", encoding='utf-8') as writer:
|
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||||
writer.write(self.to_json_string())
|
writer.write(self.to_json_string())
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
|
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
@@ -30,14 +29,14 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||||
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
|
"albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
|
||||||
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
|
"albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
|
||||||
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
|
"albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
|
||||||
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
|
"albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
|
||||||
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin",
|
"albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin",
|
||||||
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-pytorch_model.bin",
|
"albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-pytorch_model.bin",
|
||||||
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-pytorch_model.bin",
|
"albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-pytorch_model.bin",
|
||||||
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-pytorch_model.bin",
|
"albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-pytorch_model.bin",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -48,8 +47,10 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
logger.error(
|
||||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||||
|
"https://www.tensorflow.org/install/ for installation instructions."
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||||
@@ -109,7 +110,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
|||||||
if "seq_relationship" in name:
|
if "seq_relationship" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
name = name.split('/')
|
name = name.split("/")
|
||||||
|
|
||||||
# Ignore the gradients applied by the LAMB/ADAM optimizers.
|
# Ignore the gradients applied by the LAMB/ADAM optimizers.
|
||||||
if "adam_m" in name or "adam_v" in name or "global_step" in name:
|
if "adam_m" in name or "adam_v" in name or "global_step" in name:
|
||||||
@@ -118,19 +119,19 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
|||||||
|
|
||||||
pointer = model
|
pointer = model
|
||||||
for m_name in name:
|
for m_name in name:
|
||||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
||||||
l = re.split(r'_(\d+)', m_name)
|
l = re.split(r"_(\d+)", m_name)
|
||||||
else:
|
else:
|
||||||
l = [m_name]
|
l = [m_name]
|
||||||
|
|
||||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
if l[0] == "kernel" or l[0] == "gamma":
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, "weight")
|
||||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
elif l[0] == "output_bias" or l[0] == "beta":
|
||||||
pointer = getattr(pointer, 'bias')
|
pointer = getattr(pointer, "bias")
|
||||||
elif l[0] == 'output_weights':
|
elif l[0] == "output_weights":
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, "weight")
|
||||||
elif l[0] == 'squad':
|
elif l[0] == "squad":
|
||||||
pointer = getattr(pointer, 'classifier')
|
pointer = getattr(pointer, "classifier")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
pointer = getattr(pointer, l[0])
|
pointer = getattr(pointer, l[0])
|
||||||
@@ -141,9 +142,9 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
|||||||
num = int(l[1])
|
num = int(l[1])
|
||||||
pointer = pointer[num]
|
pointer = pointer[num]
|
||||||
|
|
||||||
if m_name[-11:] == '_embeddings':
|
if m_name[-11:] == "_embeddings":
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, "weight")
|
||||||
elif m_name == 'kernel':
|
elif m_name == "kernel":
|
||||||
array = np.transpose(array)
|
array = np.transpose(array)
|
||||||
try:
|
try:
|
||||||
assert pointer.shape == array.shape
|
assert pointer.shape == array.shape
|
||||||
@@ -160,6 +161,7 @@ class AlbertEmbeddings(BertEmbeddings):
|
|||||||
"""
|
"""
|
||||||
Construct the embeddings from word, position and token_type embeddings.
|
Construct the embeddings from word, position and token_type embeddings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(AlbertEmbeddings, self).__init__(config)
|
super(AlbertEmbeddings, self).__init__(config)
|
||||||
|
|
||||||
@@ -238,9 +240,12 @@ class AlbertAttention(BertSelfAttention):
|
|||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
reshaped_context_layer = context_layer.view(*new_context_layer_shape)
|
reshaped_context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
|
|
||||||
# Should find a better way to do this
|
# Should find a better way to do this
|
||||||
w = self.dense.weight.t().view(self.num_attention_heads, self.attention_head_size, self.hidden_size).to(context_layer.dtype)
|
w = (
|
||||||
|
self.dense.weight.t()
|
||||||
|
.view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
|
||||||
|
.to(context_layer.dtype)
|
||||||
|
)
|
||||||
b = self.dense.bias.to(context_layer.dtype)
|
b = self.dense.bias.to(context_layer.dtype)
|
||||||
|
|
||||||
projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
|
projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
|
||||||
@@ -328,7 +333,11 @@ class AlbertTransformer(nn.Module):
|
|||||||
# Index of the layer inside the group
|
# Index of the layer inside the group
|
||||||
layer_idx = int(i - group_idx * layers_per_group)
|
layer_idx = int(i - group_idx * layers_per_group)
|
||||||
|
|
||||||
layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group])
|
layer_group_output = self.albert_layer_groups[group_idx](
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
|
||||||
|
)
|
||||||
hidden_states = layer_group_output[0]
|
hidden_states = layer_group_output[0]
|
||||||
|
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
@@ -337,7 +346,6 @@ class AlbertTransformer(nn.Module):
|
|||||||
if self.output_hidden_states:
|
if self.output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
if self.output_hidden_states:
|
if self.output_hidden_states:
|
||||||
outputs = outputs + (all_hidden_states,)
|
outputs = outputs + (all_hidden_states,)
|
||||||
@@ -346,11 +354,11 @@ class AlbertTransformer(nn.Module):
|
|||||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AlbertPreTrainedModel(PreTrainedModel):
|
class AlbertPreTrainedModel(PreTrainedModel):
|
||||||
""" An abstract class to handle weights initialization and
|
""" An abstract class to handle weights initialization and
|
||||||
a simple interface for dowloading and loading pretrained models.
|
a simple interface for dowloading and loading pretrained models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = AlbertConfig
|
config_class = AlbertConfig
|
||||||
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
base_model_prefix = "albert"
|
base_model_prefix = "albert"
|
||||||
@@ -431,8 +439,12 @@ ALBERT_INPUTS_DOCSTRING = r"""
|
|||||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@add_start_docstrings("The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
|
|
||||||
ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings(
|
||||||
|
"The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
ALBERT_START_DOCSTRING,
|
||||||
|
ALBERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class AlbertModel(AlbertPreTrainedModel):
|
class AlbertModel(AlbertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
@@ -500,8 +512,15 @@ class AlbertModel(AlbertPreTrainedModel):
|
|||||||
inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
|
inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
|
||||||
self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
|
self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(
|
||||||
inputs_embeds=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
):
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
@@ -527,24 +546,30 @@ class AlbertModel(AlbertPreTrainedModel):
|
|||||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||||
elif head_mask.dim() == 2:
|
elif head_mask.dim() == 2:
|
||||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
head_mask = (
|
||||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
) # We can specify head_mask for each layer
|
||||||
|
head_mask = head_mask.to(
|
||||||
|
dtype=next(self.parameters()).dtype
|
||||||
|
) # switch to fload if need + fp16 compatibility
|
||||||
else:
|
else:
|
||||||
head_mask = [None] * self.config.num_hidden_layers
|
head_mask = [None] * self.config.num_hidden_layers
|
||||||
|
|
||||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
|
embedding_output = self.embeddings(
|
||||||
inputs_embeds=inputs_embeds)
|
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||||
encoder_outputs = self.encoder(embedding_output,
|
)
|
||||||
extended_attention_mask,
|
encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)
|
||||||
head_mask=head_mask)
|
|
||||||
|
|
||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
|
|
||||||
pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
|
pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
|
||||||
|
|
||||||
outputs = (sequence_output, pooled_output) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
outputs = (sequence_output, pooled_output) + encoder_outputs[
|
||||||
|
1:
|
||||||
|
] # add hidden_states and attentions if they are here
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class AlbertMLMHead(nn.Module):
|
class AlbertMLMHead(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(AlbertMLMHead, self).__init__()
|
super(AlbertMLMHead, self).__init__()
|
||||||
@@ -566,7 +591,9 @@ class AlbertMLMHead(nn.Module):
|
|||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings(
|
||||||
|
"Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING
|
||||||
|
)
|
||||||
class AlbertForMaskedLM(AlbertPreTrainedModel):
|
class AlbertForMaskedLM(AlbertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
@@ -602,21 +629,28 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
|
|||||||
""" Make sure we are sharing the input and output embeddings.
|
""" Make sure we are sharing the input and output embeddings.
|
||||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||||
"""
|
"""
|
||||||
self._tie_or_clone_weights(self.predictions.decoder,
|
self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)
|
||||||
self.albert.embeddings.word_embeddings)
|
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.predictions.decoder
|
return self.predictions.decoder
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
def forward(
|
||||||
masked_lm_labels=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
masked_lm_labels=None,
|
||||||
|
):
|
||||||
outputs = self.albert(
|
outputs = self.albert(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
sequence_outputs = outputs[0]
|
sequence_outputs = outputs[0]
|
||||||
|
|
||||||
@@ -631,9 +665,12 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
@add_start_docstrings(
|
||||||
|
"""Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||||
the pooled output) e.g. for GLUE tasks. """,
|
the pooled output) e.g. for GLUE tasks. """,
|
||||||
ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
|
ALBERT_START_DOCSTRING,
|
||||||
|
ALBERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
@@ -665,6 +702,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
|||||||
loss, logits = outputs[:2]
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(AlbertForSequenceClassification, self).__init__(config)
|
super(AlbertForSequenceClassification, self).__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
@@ -675,8 +713,16 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
def forward(
|
||||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
):
|
||||||
|
|
||||||
outputs = self.albert(
|
outputs = self.albert(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -684,7 +730,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
|||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
@@ -707,10 +753,12 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
|||||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
@add_start_docstrings("""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
"""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||||
ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
|
ALBERT_START_DOCSTRING,
|
||||||
|
ALBERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
@@ -752,6 +800,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(AlbertForQuestionAnswering, self).__init__(config)
|
super(AlbertForQuestionAnswering, self).__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
@@ -761,8 +810,17 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(
|
||||||
inputs_embeds=None, start_positions=None, end_positions=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
start_positions=None,
|
||||||
|
end_positions=None,
|
||||||
|
):
|
||||||
|
|
||||||
outputs = self.albert(
|
outputs = self.albert(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -770,7 +828,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
|||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
|
|||||||
@@ -18,31 +18,87 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .configuration_auto import (AlbertConfig, BertConfig, CamembertConfig, CTRLConfig,
|
from .configuration_auto import (
|
||||||
DistilBertConfig, GPT2Config, OpenAIGPTConfig, RobertaConfig,
|
AlbertConfig,
|
||||||
TransfoXLConfig, XLMConfig, XLNetConfig, XLMRobertaConfig)
|
BertConfig,
|
||||||
|
CamembertConfig,
|
||||||
|
CTRLConfig,
|
||||||
|
DistilBertConfig,
|
||||||
|
GPT2Config,
|
||||||
|
OpenAIGPTConfig,
|
||||||
|
RobertaConfig,
|
||||||
|
TransfoXLConfig,
|
||||||
|
XLMConfig,
|
||||||
|
XLNetConfig,
|
||||||
|
XLMRobertaConfig,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering, \
|
from .modeling_bert import (
|
||||||
BertForTokenClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
BertModel,
|
||||||
|
BertForMaskedLM,
|
||||||
|
BertForSequenceClassification,
|
||||||
|
BertForQuestionAnswering,
|
||||||
|
BertForTokenClassification,
|
||||||
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
|
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
from .modeling_ctrl import CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
from .modeling_ctrl import CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
|
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering, \
|
from .modeling_xlnet import (
|
||||||
XLNetForTokenClassification, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
XLNetModel,
|
||||||
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering, \
|
XLNetLMHeadModel,
|
||||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
XLNetForSequenceClassification,
|
||||||
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, \
|
XLNetForQuestionAnswering,
|
||||||
RobertaForTokenClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
XLNetForTokenClassification,
|
||||||
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, \
|
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
DistilBertForSequenceClassification, DistilBertForTokenClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
)
|
||||||
from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, \
|
from .modeling_xlm import (
|
||||||
CamembertForMultipleChoice, CamembertForTokenClassification, CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
XLMModel,
|
||||||
from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, \
|
XLMWithLMHeadModel,
|
||||||
AlbertForQuestionAnswering, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
XLMForSequenceClassification,
|
||||||
|
XLMForQuestionAnswering,
|
||||||
|
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
from .modeling_roberta import (
|
||||||
|
RobertaModel,
|
||||||
|
RobertaForMaskedLM,
|
||||||
|
RobertaForSequenceClassification,
|
||||||
|
RobertaForTokenClassification,
|
||||||
|
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
from .modeling_distilbert import (
|
||||||
|
DistilBertModel,
|
||||||
|
DistilBertForQuestionAnswering,
|
||||||
|
DistilBertForMaskedLM,
|
||||||
|
DistilBertForSequenceClassification,
|
||||||
|
DistilBertForTokenClassification,
|
||||||
|
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
from .modeling_camembert import (
|
||||||
|
CamembertModel,
|
||||||
|
CamembertForMaskedLM,
|
||||||
|
CamembertForSequenceClassification,
|
||||||
|
CamembertForMultipleChoice,
|
||||||
|
CamembertForTokenClassification,
|
||||||
|
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
from .modeling_albert import (
|
||||||
|
AlbertModel,
|
||||||
|
AlbertForMaskedLM,
|
||||||
|
AlbertForSequenceClassification,
|
||||||
|
AlbertForQuestionAnswering,
|
||||||
|
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
from .modeling_t5 import T5Model, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
from .modeling_t5 import T5Model, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification, \
|
from .modeling_xlm_roberta import (
|
||||||
XLMRobertaForMultipleChoice, XLMRobertaForTokenClassification, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
XLMRobertaModel,
|
||||||
|
XLMRobertaForMaskedLM,
|
||||||
|
XLMRobertaForSequenceClassification,
|
||||||
|
XLMRobertaForMultipleChoice,
|
||||||
|
XLMRobertaForTokenClassification,
|
||||||
|
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||||
|
|
||||||
@@ -51,7 +107,8 @@ from .file_utils import add_start_docstrings
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict((key, value)
|
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
|
||||||
|
(key, value)
|
||||||
for pretrained_map in [
|
for pretrained_map in [
|
||||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
@@ -67,7 +124,8 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict((key, value)
|
|||||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
]
|
]
|
||||||
for key, value, in pretrained_map.items())
|
for key, value, in pretrained_map.items()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoModel(object):
|
class AutoModel(object):
|
||||||
@@ -98,10 +156,13 @@ class AutoModel(object):
|
|||||||
|
|
||||||
This class cannot be instantiated using `__init__()` (throws an error).
|
This class cannot be instantiated using `__init__()` (throws an error).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
raise EnvironmentError("AutoModel is designed to be instantiated "
|
raise EnvironmentError(
|
||||||
|
"AutoModel is designed to be instantiated "
|
||||||
"using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` or "
|
"using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
"`AutoModel.from_config(config)` methods.")
|
"`AutoModel.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config):
|
def from_config(cls, config):
|
||||||
@@ -232,35 +293,39 @@ class AutoModel(object):
|
|||||||
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if 't5' in pretrained_model_name_or_path:
|
if "t5" in pretrained_model_name_or_path:
|
||||||
return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'distilbert' in pretrained_model_name_or_path:
|
elif "distilbert" in pretrained_model_name_or_path:
|
||||||
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'albert' in pretrained_model_name_or_path:
|
elif "albert" in pretrained_model_name_or_path:
|
||||||
return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'camembert' in pretrained_model_name_or_path:
|
elif "camembert" in pretrained_model_name_or_path:
|
||||||
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlm-roberta' in pretrained_model_name_or_path:
|
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||||
return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'roberta' in pretrained_model_name_or_path:
|
elif "roberta" in pretrained_model_name_or_path:
|
||||||
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'bert' in pretrained_model_name_or_path:
|
elif "bert" in pretrained_model_name_or_path:
|
||||||
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'openai-gpt' in pretrained_model_name_or_path:
|
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||||
return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'gpt2' in pretrained_model_name_or_path:
|
elif "gpt2" in pretrained_model_name_or_path:
|
||||||
return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'transfo-xl' in pretrained_model_name_or_path:
|
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||||
return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlnet' in pretrained_model_name_or_path:
|
elif "xlnet" in pretrained_model_name_or_path:
|
||||||
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlm' in pretrained_model_name_or_path:
|
elif "xlm" in pretrained_model_name_or_path:
|
||||||
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'ctrl' in pretrained_model_name_or_path:
|
elif "ctrl" in pretrained_model_name_or_path:
|
||||||
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError(
|
||||||
|
"Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
"'xlm-roberta', 'xlm', 'roberta, 'ctrl', 'distilbert', 'camembert', 'albert'".format(pretrained_model_name_or_path))
|
"'xlm-roberta', 'xlm', 'roberta, 'ctrl', 'distilbert', 'camembert', 'albert'".format(
|
||||||
|
pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoModelWithLMHead(object):
|
class AutoModelWithLMHead(object):
|
||||||
@@ -291,10 +356,13 @@ class AutoModelWithLMHead(object):
|
|||||||
|
|
||||||
This class cannot be instantiated using `__init__()` (throws an error).
|
This class cannot be instantiated using `__init__()` (throws an error).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
raise EnvironmentError("AutoModelWithLMHead is designed to be instantiated "
|
raise EnvironmentError(
|
||||||
|
"AutoModelWithLMHead is designed to be instantiated "
|
||||||
"using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or "
|
"using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
"`AutoModelWithLMHead.from_config(config)` methods.")
|
"`AutoModelWithLMHead.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config):
|
def from_config(cls, config):
|
||||||
@@ -423,35 +491,39 @@ class AutoModelWithLMHead(object):
|
|||||||
model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if 't5' in pretrained_model_name_or_path:
|
if "t5" in pretrained_model_name_or_path:
|
||||||
return T5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return T5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'distilbert' in pretrained_model_name_or_path:
|
elif "distilbert" in pretrained_model_name_or_path:
|
||||||
return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'albert' in pretrained_model_name_or_path:
|
elif "albert" in pretrained_model_name_or_path:
|
||||||
return AlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return AlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'camembert' in pretrained_model_name_or_path:
|
elif "camembert" in pretrained_model_name_or_path:
|
||||||
return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlm-roberta' in pretrained_model_name_or_path:
|
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||||
return XLMRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLMRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'roberta' in pretrained_model_name_or_path:
|
elif "roberta" in pretrained_model_name_or_path:
|
||||||
return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'bert' in pretrained_model_name_or_path:
|
elif "bert" in pretrained_model_name_or_path:
|
||||||
return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'openai-gpt' in pretrained_model_name_or_path:
|
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||||
return OpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return OpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'gpt2' in pretrained_model_name_or_path:
|
elif "gpt2" in pretrained_model_name_or_path:
|
||||||
return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'transfo-xl' in pretrained_model_name_or_path:
|
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||||
return TransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return TransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlnet' in pretrained_model_name_or_path:
|
elif "xlnet" in pretrained_model_name_or_path:
|
||||||
return XLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlm' in pretrained_model_name_or_path:
|
elif "xlm" in pretrained_model_name_or_path:
|
||||||
return XLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'ctrl' in pretrained_model_name_or_path:
|
elif "ctrl" in pretrained_model_name_or_path:
|
||||||
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError(
|
||||||
|
"Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
"'xlm-roberta', 'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'".format(pretrained_model_name_or_path))
|
"'xlm-roberta', 'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'".format(
|
||||||
|
pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForSequenceClassification(object):
|
class AutoModelForSequenceClassification(object):
|
||||||
@@ -477,10 +549,13 @@ class AutoModelForSequenceClassification(object):
|
|||||||
|
|
||||||
This class cannot be instantiated using `__init__()` (throws an error).
|
This class cannot be instantiated using `__init__()` (throws an error).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
raise EnvironmentError("AutoModelForSequenceClassification is designed to be instantiated "
|
raise EnvironmentError(
|
||||||
|
"AutoModelForSequenceClassification is designed to be instantiated "
|
||||||
"using the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or "
|
"using the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
"`AutoModelForSequenceClassification.from_config(config)` methods.")
|
"`AutoModelForSequenceClassification.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config):
|
def from_config(cls, config):
|
||||||
@@ -597,25 +672,39 @@ class AutoModelForSequenceClassification(object):
|
|||||||
model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if 'distilbert' in pretrained_model_name_or_path:
|
if "distilbert" in pretrained_model_name_or_path:
|
||||||
return DistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return DistilBertForSequenceClassification.from_pretrained(
|
||||||
elif 'albert' in pretrained_model_name_or_path:
|
pretrained_model_name_or_path, *model_args, **kwargs
|
||||||
return AlbertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
)
|
||||||
elif 'camembert' in pretrained_model_name_or_path:
|
elif "albert" in pretrained_model_name_or_path:
|
||||||
return CamembertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return AlbertForSequenceClassification.from_pretrained(
|
||||||
elif 'xlm-roberta' in pretrained_model_name_or_path:
|
pretrained_model_name_or_path, *model_args, **kwargs
|
||||||
return XLMRobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
)
|
||||||
elif 'roberta' in pretrained_model_name_or_path:
|
elif "camembert" in pretrained_model_name_or_path:
|
||||||
return RobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return CamembertForSequenceClassification.from_pretrained(
|
||||||
elif 'bert' in pretrained_model_name_or_path:
|
pretrained_model_name_or_path, *model_args, **kwargs
|
||||||
|
)
|
||||||
|
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||||
|
return XLMRobertaForSequenceClassification.from_pretrained(
|
||||||
|
pretrained_model_name_or_path, *model_args, **kwargs
|
||||||
|
)
|
||||||
|
elif "roberta" in pretrained_model_name_or_path:
|
||||||
|
return RobertaForSequenceClassification.from_pretrained(
|
||||||
|
pretrained_model_name_or_path, *model_args, **kwargs
|
||||||
|
)
|
||||||
|
elif "bert" in pretrained_model_name_or_path:
|
||||||
return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlnet' in pretrained_model_name_or_path:
|
elif "xlnet" in pretrained_model_name_or_path:
|
||||||
return XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlm' in pretrained_model_name_or_path:
|
elif "xlm" in pretrained_model_name_or_path:
|
||||||
return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError(
|
||||||
"'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'".format(pretrained_model_name_or_path))
|
"Unrecognized model identifier in {}. Should contains one of "
|
||||||
|
"'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'".format(
|
||||||
|
pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForQuestionAnswering(object):
|
class AutoModelForQuestionAnswering(object):
|
||||||
@@ -638,10 +727,13 @@ class AutoModelForQuestionAnswering(object):
|
|||||||
|
|
||||||
This class cannot be instantiated using `__init__()` (throws an error).
|
This class cannot be instantiated using `__init__()` (throws an error).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
raise EnvironmentError("AutoModelForQuestionAnswering is designed to be instantiated "
|
raise EnvironmentError(
|
||||||
|
"AutoModelForQuestionAnswering is designed to be instantiated "
|
||||||
"using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
|
"using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
"`AutoModelForQuestionAnswering.from_config(config)` methods.")
|
"`AutoModelForQuestionAnswering.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config):
|
def from_config(cls, config):
|
||||||
@@ -745,26 +837,30 @@ class AutoModelForQuestionAnswering(object):
|
|||||||
model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if 'distilbert' in pretrained_model_name_or_path:
|
if "distilbert" in pretrained_model_name_or_path:
|
||||||
return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'albert' in pretrained_model_name_or_path:
|
elif "albert" in pretrained_model_name_or_path:
|
||||||
return AlbertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return AlbertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'bert' in pretrained_model_name_or_path:
|
elif "bert" in pretrained_model_name_or_path:
|
||||||
return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlnet' in pretrained_model_name_or_path:
|
elif "xlnet" in pretrained_model_name_or_path:
|
||||||
return XLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlm' in pretrained_model_name_or_path:
|
elif "xlm" in pretrained_model_name_or_path:
|
||||||
return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError(
|
||||||
"'bert', 'xlnet', 'xlm', 'distilbert', 'albert'".format(pretrained_model_name_or_path))
|
"Unrecognized model identifier in {}. Should contains one of "
|
||||||
|
"'bert', 'xlnet', 'xlm', 'distilbert', 'albert'".format(pretrained_model_name_or_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForTokenClassification:
|
class AutoModelForTokenClassification:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
raise EnvironmentError("AutoModelForTokenClassification is designed to be instantiated "
|
raise EnvironmentError(
|
||||||
|
"AutoModelForTokenClassification is designed to be instantiated "
|
||||||
"using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
|
"using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
"`AutoModelForTokenClassification.from_config(config)` methods.")
|
"`AutoModelForTokenClassification.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config):
|
def from_config(cls, config):
|
||||||
@@ -870,18 +966,28 @@ class AutoModelForTokenClassification:
|
|||||||
model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if 'camembert' in pretrained_model_name_or_path:
|
if "camembert" in pretrained_model_name_or_path:
|
||||||
return CamembertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return CamembertForTokenClassification.from_pretrained(
|
||||||
elif 'distilbert' in pretrained_model_name_or_path:
|
pretrained_model_name_or_path, *model_args, **kwargs
|
||||||
return DistilBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
)
|
||||||
elif 'xlm-roberta' in pretrained_model_name_or_path:
|
elif "distilbert" in pretrained_model_name_or_path:
|
||||||
return XLMRobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return DistilBertForTokenClassification.from_pretrained(
|
||||||
elif 'roberta' in pretrained_model_name_or_path:
|
pretrained_model_name_or_path, *model_args, **kwargs
|
||||||
|
)
|
||||||
|
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||||
|
return XLMRobertaForTokenClassification.from_pretrained(
|
||||||
|
pretrained_model_name_or_path, *model_args, **kwargs
|
||||||
|
)
|
||||||
|
elif "roberta" in pretrained_model_name_or_path:
|
||||||
return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'bert' in pretrained_model_name_or_path:
|
elif "bert" in pretrained_model_name_or_path:
|
||||||
return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlnet' in pretrained_model_name_or_path:
|
elif "xlnet" in pretrained_model_name_or_path:
|
||||||
return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError(
|
||||||
"'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(pretrained_model_name_or_path))
|
"Unrecognized model identifier in {}. Should contains one of "
|
||||||
|
"'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(
|
||||||
|
pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
@@ -33,27 +33,27 @@ from .file_utils import add_start_docstrings
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
|
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
|
||||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
|
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
|
||||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
|
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
|
||||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
|
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
|
||||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
|
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
|
||||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
|
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
|
||||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
|
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
|
||||||
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
|
"bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
|
||||||
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
|
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
|
||||||
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
|
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
|
||||||
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
||||||
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
||||||
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
|
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
|
||||||
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
|
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
|
||||||
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
|
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
|
||||||
'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
|
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
|
||||||
'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
|
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
|
||||||
'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
|
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
|
||||||
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
|
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
|
||||||
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
|
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
|
||||||
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
|
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -65,8 +65,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
logger.error(
|
||||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||||
|
"https://www.tensorflow.org/install/ for installation instructions."
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||||
@@ -81,7 +83,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|||||||
arrays.append(array)
|
arrays.append(array)
|
||||||
|
|
||||||
for name, array in zip(names, arrays):
|
for name, array in zip(names, arrays):
|
||||||
name = name.split('/')
|
name = name.split("/")
|
||||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||||
# which are not required for using pretrained model
|
# which are not required for using pretrained model
|
||||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
||||||
@@ -89,18 +91,18 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|||||||
continue
|
continue
|
||||||
pointer = model
|
pointer = model
|
||||||
for m_name in name:
|
for m_name in name:
|
||||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
||||||
l = re.split(r'_(\d+)', m_name)
|
l = re.split(r"_(\d+)", m_name)
|
||||||
else:
|
else:
|
||||||
l = [m_name]
|
l = [m_name]
|
||||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
if l[0] == "kernel" or l[0] == "gamma":
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, "weight")
|
||||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
elif l[0] == "output_bias" or l[0] == "beta":
|
||||||
pointer = getattr(pointer, 'bias')
|
pointer = getattr(pointer, "bias")
|
||||||
elif l[0] == 'output_weights':
|
elif l[0] == "output_weights":
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, "weight")
|
||||||
elif l[0] == 'squad':
|
elif l[0] == "squad":
|
||||||
pointer = getattr(pointer, 'classifier')
|
pointer = getattr(pointer, "classifier")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
pointer = getattr(pointer, l[0])
|
pointer = getattr(pointer, l[0])
|
||||||
@@ -110,9 +112,9 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|||||||
if len(l) >= 2:
|
if len(l) >= 2:
|
||||||
num = int(l[1])
|
num = int(l[1])
|
||||||
pointer = pointer[num]
|
pointer = pointer[num]
|
||||||
if m_name[-11:] == '_embeddings':
|
if m_name[-11:] == "_embeddings":
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, "weight")
|
||||||
elif m_name == 'kernel':
|
elif m_name == "kernel":
|
||||||
array = np.transpose(array)
|
array = np.transpose(array)
|
||||||
try:
|
try:
|
||||||
assert pointer.shape == array.shape
|
assert pointer.shape == array.shape
|
||||||
@@ -157,6 +159,7 @@ BertLayerNorm = torch.nn.LayerNorm
|
|||||||
class BertEmbeddings(nn.Module):
|
class BertEmbeddings(nn.Module):
|
||||||
"""Construct the embeddings from word, position and token_type embeddings.
|
"""Construct the embeddings from word, position and token_type embeddings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertEmbeddings, self).__init__()
|
super(BertEmbeddings, self).__init__()
|
||||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
||||||
@@ -199,7 +202,8 @@ class BertSelfAttention(nn.Module):
|
|||||||
if config.hidden_size % config.num_attention_heads != 0:
|
if config.hidden_size % config.num_attention_heads != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The hidden size (%d) is not a multiple of the number of attention "
|
"The hidden size (%d) is not a multiple of the number of attention "
|
||||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
||||||
|
)
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
|
|
||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
@@ -217,7 +221,14 @@ class BertSelfAttention(nn.Module):
|
|||||||
x = x.view(*new_x_shape)
|
x = x.view(*new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
):
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
@@ -307,8 +318,17 @@ class BertAttention(nn.Module):
|
|||||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||||
self.pruned_heads = self.pruned_heads.union(heads)
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
def forward(
|
||||||
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
):
|
||||||
|
self_outputs = self.self(
|
||||||
|
hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
|
||||||
|
)
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||||
return outputs
|
return outputs
|
||||||
@@ -353,13 +373,22 @@ class BertLayer(nn.Module):
|
|||||||
self.intermediate = BertIntermediate(config)
|
self.intermediate = BertIntermediate(config)
|
||||||
self.output = BertOutput(config)
|
self.output = BertOutput(config)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
):
|
||||||
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||||
attention_output = self_attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
if self.is_decoder and encoder_hidden_states is not None:
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
|
cross_attention_outputs = self.crossattention(
|
||||||
|
attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
|
||||||
|
)
|
||||||
attention_output = cross_attention_outputs[0]
|
attention_output = cross_attention_outputs[0]
|
||||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
@@ -376,14 +405,23 @@ class BertEncoder(nn.Module):
|
|||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
):
|
||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
all_attentions = ()
|
all_attentions = ()
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if self.output_hidden_states:
|
if self.output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
|
layer_outputs = layer_module(
|
||||||
|
hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
|
||||||
|
)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
@@ -440,9 +478,7 @@ class BertLMPredictionHead(nn.Module):
|
|||||||
|
|
||||||
# The output weights are the same as the input embeddings, but there is
|
# The output weights are the same as the input embeddings, but there is
|
||||||
# an output-only bias for each token.
|
# an output-only bias for each token.
|
||||||
self.decoder = nn.Linear(config.hidden_size,
|
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
config.vocab_size,
|
|
||||||
bias=False)
|
|
||||||
|
|
||||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||||
|
|
||||||
@@ -488,6 +524,7 @@ class BertPreTrainedModel(PreTrainedModel):
|
|||||||
""" An abstract class to handle weights initialization and
|
""" An abstract class to handle weights initialization and
|
||||||
a simple interface for dowloading and loading pretrained models.
|
a simple interface for dowloading and loading pretrained models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = BertConfig
|
config_class = BertConfig
|
||||||
pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_tf_weights = load_tf_weights_in_bert
|
load_tf_weights = load_tf_weights_in_bert
|
||||||
@@ -581,8 +618,12 @@ BERT_INPUTS_DOCSTRING = r"""
|
|||||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
|
||||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
@add_start_docstrings(
|
||||||
|
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
BERT_START_DOCSTRING,
|
||||||
|
BERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class BertModel(BertPreTrainedModel):
|
class BertModel(BertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
@@ -612,6 +653,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertModel, self).__init__(config)
|
super(BertModel, self).__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -636,8 +678,17 @@ class BertModel(BertPreTrainedModel):
|
|||||||
for layer, heads in heads_to_prune.items():
|
for layer, heads in heads_to_prune.items():
|
||||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
|
def forward(
|
||||||
head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
):
|
||||||
""" Forward pass on the Model.
|
""" Forward pass on the Model.
|
||||||
|
|
||||||
The model can behave as an encoder (with only self-attention) as well
|
The model can behave as an encoder (with only self-attention) as well
|
||||||
@@ -681,12 +732,18 @@ class BertModel(BertPreTrainedModel):
|
|||||||
batch_size, seq_length = input_shape
|
batch_size, seq_length = input_shape
|
||||||
seq_ids = torch.arange(seq_length, device=device)
|
seq_ids = torch.arange(seq_length, device=device)
|
||||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||||
causal_mask = causal_mask.to(torch.long) # not converting to long will cause errors with pytorch version < 1.3
|
causal_mask = causal_mask.to(
|
||||||
|
torch.long
|
||||||
|
) # not converting to long will cause errors with pytorch version < 1.3
|
||||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||||
else:
|
else:
|
||||||
extended_attention_mask = attention_mask[:, None, None, :]
|
extended_attention_mask = attention_mask[:, None, None, :]
|
||||||
else:
|
else:
|
||||||
raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))
|
raise ValueError(
|
||||||
|
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
||||||
|
input_shape, attention_mask.shape
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
@@ -709,10 +766,15 @@ class BertModel(BertPreTrainedModel):
|
|||||||
elif encoder_attention_mask.dim() == 2:
|
elif encoder_attention_mask.dim() == 2:
|
||||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||||
else:
|
else:
|
||||||
raise ValueError("Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(encoder_hidden_shape,
|
raise ValueError(
|
||||||
encoder_attention_mask.shape))
|
"Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(
|
||||||
|
encoder_hidden_shape, encoder_attention_mask.shape
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
encoder_extended_attention_mask = encoder_extended_attention_mask.to(
|
||||||
|
dtype=next(self.parameters()).dtype
|
||||||
|
) # fp16 compatibility
|
||||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||||
else:
|
else:
|
||||||
encoder_extended_attention_mask = None
|
encoder_extended_attention_mask = None
|
||||||
@@ -727,28 +789,40 @@ class BertModel(BertPreTrainedModel):
|
|||||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||||
elif head_mask.dim() == 2:
|
elif head_mask.dim() == 2:
|
||||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
head_mask = (
|
||||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
) # We can specify head_mask for each layer
|
||||||
|
head_mask = head_mask.to(
|
||||||
|
dtype=next(self.parameters()).dtype
|
||||||
|
) # switch to fload if need + fp16 compatibility
|
||||||
else:
|
else:
|
||||||
head_mask = [None] * self.config.num_hidden_layers
|
head_mask = [None] * self.config.num_hidden_layers
|
||||||
|
|
||||||
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
|
embedding_output = self.embeddings(
|
||||||
encoder_outputs = self.encoder(embedding_output,
|
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||||
|
)
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
embedding_output,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_extended_attention_mask)
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
)
|
||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
pooled_output = self.pooler(sequence_output)
|
pooled_output = self.pooler(sequence_output)
|
||||||
|
|
||||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
||||||
|
1:
|
||||||
|
] # add hidden_states and attentions if they are here
|
||||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
|
@add_start_docstrings(
|
||||||
|
"""Bert Model with two heads on top as done during the pre-training:
|
||||||
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
BERT_INPUTS_DOCSTRING)
|
BERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class BertForPreTraining(BertPreTrainedModel):
|
class BertForPreTraining(BertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
@@ -786,6 +860,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
prediction_scores, seq_relationship_scores = outputs[:2]
|
prediction_scores, seq_relationship_scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertForPreTraining, self).__init__(config)
|
super(BertForPreTraining, self).__init__(config)
|
||||||
|
|
||||||
@@ -797,20 +872,33 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
def forward(
|
||||||
masked_lm_labels=None, next_sentence_label=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
masked_lm_labels=None,
|
||||||
|
next_sentence_label=None,
|
||||||
|
):
|
||||||
|
|
||||||
outputs = self.bert(input_ids,
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
sequence_output, pooled_output = outputs[:2]
|
sequence_output, pooled_output = outputs[:2]
|
||||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||||
|
|
||||||
outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
|
outputs = (prediction_scores, seq_relationship_score,) + outputs[
|
||||||
|
2:
|
||||||
|
] # add hidden states and attention if they are here
|
||||||
|
|
||||||
if masked_lm_labels is not None and next_sentence_label is not None:
|
if masked_lm_labels is not None and next_sentence_label is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
@@ -822,9 +910,9 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
|
@add_start_docstrings(
|
||||||
BERT_START_DOCSTRING,
|
"""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING
|
||||||
BERT_INPUTS_DOCSTRING)
|
)
|
||||||
class BertForMaskedLM(BertPreTrainedModel):
|
class BertForMaskedLM(BertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
@@ -862,6 +950,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
loss, prediction_scores = outputs[:2]
|
loss, prediction_scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertForMaskedLM, self).__init__(config)
|
super(BertForMaskedLM, self).__init__(config)
|
||||||
|
|
||||||
@@ -873,17 +962,30 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
def forward(
|
||||||
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
masked_lm_labels=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
lm_labels=None,
|
||||||
|
):
|
||||||
|
|
||||||
outputs = self.bert(input_ids,
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask)
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
prediction_scores = self.cls(sequence_output)
|
prediction_scores = self.cls(sequence_output)
|
||||||
@@ -912,9 +1014,11 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
return outputs # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
|
return outputs # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
@add_start_docstrings(
|
||||||
|
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
BERT_INPUTS_DOCSTRING)
|
BERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class BertForNextSentencePrediction(BertPreTrainedModel):
|
class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
**next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
@@ -945,6 +1049,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
|||||||
seq_relationship_scores = outputs[0]
|
seq_relationship_scores = outputs[0]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertForNextSentencePrediction, self).__init__(config)
|
super(BertForNextSentencePrediction, self).__init__(config)
|
||||||
|
|
||||||
@@ -953,15 +1058,25 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
def forward(
|
||||||
next_sentence_label=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
next_sentence_label=None,
|
||||||
|
):
|
||||||
|
|
||||||
outputs = self.bert(input_ids,
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
@@ -976,10 +1091,12 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
|||||||
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
|
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
@add_start_docstrings(
|
||||||
|
"""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||||
the pooled output) e.g. for GLUE tasks. """,
|
the pooled output) e.g. for GLUE tasks. """,
|
||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
BERT_INPUTS_DOCSTRING)
|
BERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class BertForSequenceClassification(BertPreTrainedModel):
|
class BertForSequenceClassification(BertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
@@ -1011,6 +1128,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|||||||
loss, logits = outputs[:2]
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertForSequenceClassification, self).__init__(config)
|
super(BertForSequenceClassification, self).__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
@@ -1021,15 +1139,25 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
def forward(
|
||||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
):
|
||||||
|
|
||||||
outputs = self.bert(input_ids,
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
@@ -1051,10 +1179,12 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|||||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
|
@add_start_docstrings(
|
||||||
|
"""Bert Model with a multiple choice classification head on top (a linear layer on top of
|
||||||
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
BERT_INPUTS_DOCSTRING)
|
BERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class BertForMultipleChoice(BertPreTrainedModel):
|
class BertForMultipleChoice(BertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
@@ -1087,6 +1217,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
loss, classification_scores = outputs[:2]
|
loss, classification_scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertForMultipleChoice, self).__init__(config)
|
super(BertForMultipleChoice, self).__init__(config)
|
||||||
|
|
||||||
@@ -1096,8 +1227,16 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
def forward(
|
||||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
):
|
||||||
num_choices = input_ids.shape[1]
|
num_choices = input_ids.shape[1]
|
||||||
|
|
||||||
input_ids = input_ids.view(-1, input_ids.size(-1))
|
input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||||
@@ -1105,12 +1244,14 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||||
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||||
|
|
||||||
outputs = self.bert(input_ids,
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
@@ -1128,10 +1269,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
|
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
|
@add_start_docstrings(
|
||||||
|
"""Bert Model with a token classification head on top (a linear layer on top of
|
||||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
BERT_INPUTS_DOCSTRING)
|
BERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class BertForTokenClassification(BertPreTrainedModel):
|
class BertForTokenClassification(BertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
@@ -1161,6 +1304,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
|||||||
loss, scores = outputs[:2]
|
loss, scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertForTokenClassification, self).__init__(config)
|
super(BertForTokenClassification, self).__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
@@ -1171,15 +1315,25 @@ class BertForTokenClassification(BertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
def forward(
|
||||||
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
):
|
||||||
|
|
||||||
outputs = self.bert(input_ids,
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
@@ -1202,10 +1356,12 @@ class BertForTokenClassification(BertPreTrainedModel):
|
|||||||
return outputs # (loss), scores, (hidden_states), (attentions)
|
return outputs # (loss), scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
@add_start_docstrings(
|
||||||
|
"""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
BERT_INPUTS_DOCSTRING)
|
BERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class BertForQuestionAnswering(BertPreTrainedModel):
|
class BertForQuestionAnswering(BertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
@@ -1247,6 +1403,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertForQuestionAnswering, self).__init__(config)
|
super(BertForQuestionAnswering, self).__init__(config)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
@@ -1256,15 +1413,26 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
|
def forward(
|
||||||
start_positions=None, end_positions=None):
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
start_positions=None,
|
||||||
|
end_positions=None,
|
||||||
|
):
|
||||||
|
|
||||||
outputs = self.bert(input_ids,
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -15,19 +15,24 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch CamemBERT model. """
|
"""PyTorch CamemBERT model. """
|
||||||
|
|
||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
unicode_literals)
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification
|
from .modeling_roberta import (
|
||||||
|
RobertaModel,
|
||||||
|
RobertaForMaskedLM,
|
||||||
|
RobertaForSequenceClassification,
|
||||||
|
RobertaForMultipleChoice,
|
||||||
|
RobertaForTokenClassification,
|
||||||
|
)
|
||||||
from .configuration_camembert import CamembertConfig
|
from .configuration_camembert import CamembertConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||||
'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-pytorch_model.bin",
|
"camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-pytorch_model.bin",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -100,8 +105,12 @@ CAMEMBERT_INPUTS_DOCSTRING = r"""
|
|||||||
than the model's internal embedding lookup matrix.
|
than the model's internal embedding lookup matrix.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@add_start_docstrings("The bare CamemBERT Model transformer outputting raw hidden-states without any specific head on top.",
|
|
||||||
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings(
|
||||||
|
"The bare CamemBERT Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
CAMEMBERT_START_DOCSTRING,
|
||||||
|
CAMEMBERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class CamembertModel(RobertaModel):
|
class CamembertModel(RobertaModel):
|
||||||
r"""
|
r"""
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
@@ -149,8 +158,11 @@ class CamembertModel(RobertaModel):
|
|||||||
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""CamemBERT Model with a `language modeling` head on top. """,
|
@add_start_docstrings(
|
||||||
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
|
"""CamemBERT Model with a `language modeling` head on top. """,
|
||||||
|
CAMEMBERT_START_DOCSTRING,
|
||||||
|
CAMEMBERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class CamembertForMaskedLM(RobertaForMaskedLM):
|
class CamembertForMaskedLM(RobertaForMaskedLM):
|
||||||
r"""
|
r"""
|
||||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
@@ -185,9 +197,12 @@ class CamembertForMaskedLM(RobertaForMaskedLM):
|
|||||||
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""CamemBERT Model transformer with a sequence classification/regression head on top (a linear layer
|
@add_start_docstrings(
|
||||||
|
"""CamemBERT Model transformer with a sequence classification/regression head on top (a linear layer
|
||||||
on top of the pooled output) e.g. for GLUE tasks. """,
|
on top of the pooled output) e.g. for GLUE tasks. """,
|
||||||
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
|
CAMEMBERT_START_DOCSTRING,
|
||||||
|
CAMEMBERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class CamembertForSequenceClassification(RobertaForSequenceClassification):
|
class CamembertForSequenceClassification(RobertaForSequenceClassification):
|
||||||
r"""
|
r"""
|
||||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
@@ -223,9 +238,12 @@ class CamembertForSequenceClassification(RobertaForSequenceClassification):
|
|||||||
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""CamemBERT Model with a multiple choice classification head on top (a linear layer on top of
|
@add_start_docstrings(
|
||||||
|
"""CamemBERT Model with a multiple choice classification head on top (a linear layer on top of
|
||||||
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
||||||
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
|
CAMEMBERT_START_DOCSTRING,
|
||||||
|
CAMEMBERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class CamembertForMultipleChoice(RobertaForMultipleChoice):
|
class CamembertForMultipleChoice(RobertaForMultipleChoice):
|
||||||
r"""
|
r"""
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
@@ -257,9 +275,12 @@ class CamembertForMultipleChoice(RobertaForMultipleChoice):
|
|||||||
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""CamemBERT Model with a token classification head on top (a linear layer on top of
|
@add_start_docstrings(
|
||||||
|
"""CamemBERT Model with a token classification head on top (a linear layer on top of
|
||||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||||
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
|
CAMEMBERT_START_DOCSTRING,
|
||||||
|
CAMEMBERT_INPUTS_DOCSTRING,
|
||||||
|
)
|
||||||
class CamembertForTokenClassification(RobertaForTokenClassification):
|
class CamembertForTokenClassification(RobertaForTokenClassification):
|
||||||
r"""
|
r"""
|
||||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user