Merge branch 'master' into generative-finetuning
This commit is contained in:
@@ -76,7 +76,7 @@ import torch
|
|||||||
from pytorch_transformers import *
|
from pytorch_transformers import *
|
||||||
|
|
||||||
# PyTorch-Transformers has a unified API
|
# PyTorch-Transformers has a unified API
|
||||||
# for 6 transformer architectures and 27 pretrained weights.
|
# for 7 transformer architectures and 30 pretrained weights.
|
||||||
# Model | Tokenizer | Pretrained weights shortcut
|
# Model | Tokenizer | Pretrained weights shortcut
|
||||||
MODELS = [(BertModel, BertTokenizer, 'bert-base-uncased'),
|
MODELS = [(BertModel, BertTokenizer, 'bert-base-uncased'),
|
||||||
(OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt'),
|
(OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt'),
|
||||||
@@ -328,7 +328,7 @@ Breaking change in the `from_pretrained()`method:
|
|||||||
|
|
||||||
1. Models are now set in evaluation mode by default when instantiated with the `from_pretrained()` method. To train them don't forget to set them back in training mode (`model.train()`) to activate the dropout modules.
|
1. Models are now set in evaluation mode by default when instantiated with the `from_pretrained()` method. To train them don't forget to set them back in training mode (`model.train()`) to activate the dropout modules.
|
||||||
|
|
||||||
2. The additional `*input` and `**kwargs` arguments supplied to the `from_pretrained()` method used to be directly passed to the underlying model's class `__init__()` method. They are now used to update the model configuration attribute instead which can break derived model classes build based on the previous `BertForSequenceClassification` examples. We are working on a way to mitigate this breaking change in [#866](https://github.com/huggingface/pytorch-transformers/pull/866) by forwarding the the model `__init__()` method (i) the provided positional arguments and (ii) the keyword arguments which do not match any configuratoin class attributes.
|
2. The additional `*input` and `**kwargs` arguments supplied to the `from_pretrained()` method used to be directly passed to the underlying model's class `__init__()` method. They are now used to update the model configuration attribute instead which can break derived model classes build based on the previous `BertForSequenceClassification` examples. We are working on a way to mitigate this breaking change in [#866](https://github.com/huggingface/pytorch-transformers/pull/866) by forwarding the the model `__init__()` method (i) the provided positional arguments and (ii) the keyword arguments which do not match any configuration class attributes.
|
||||||
|
|
||||||
Also, while not a breaking change, the serialization methods have been standardized and you probably should switch to the new method `save_pretrained(save_directory)` if you were using any other serialization method before.
|
Also, while not a breaking change, the serialization methods have been standardized and you probably should switch to the new method `save_pretrained(save_directory)` if you were using any other serialization method before.
|
||||||
|
|
||||||
@@ -393,8 +393,8 @@ for batch in train_data:
|
|||||||
loss = model(batch)
|
loss = model(batch)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
|
||||||
scheduler.step()
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,9 @@ GLUE results on dev set
|
|||||||
~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
We get the following results on the dev set of GLUE benchmark with an uncased BERT base
|
We get the following results on the dev set of GLUE benchmark with an uncased BERT base
|
||||||
model. All experiments were run on a P100 GPU with a batch size of 32.
|
model (`bert-base-uncased`). All experiments ran on 8 V100 GPUs with a total train batch size of 24. Some of
|
||||||
|
these tasks have a small dataset and training can lead to high variance in the results between different runs.
|
||||||
|
We report the median on 5 runs (with different seeds) for each of the metrics.
|
||||||
|
|
||||||
.. list-table::
|
.. list-table::
|
||||||
:header-rows: 1
|
:header-rows: 1
|
||||||
@@ -78,31 +80,31 @@ model. All experiments were run on a P100 GPU with a batch size of 32.
|
|||||||
- Result
|
- Result
|
||||||
* - CoLA
|
* - CoLA
|
||||||
- Matthew's corr.
|
- Matthew's corr.
|
||||||
- 57.29
|
- 55.75
|
||||||
* - SST-2
|
* - SST-2
|
||||||
- accuracy
|
- accuracy
|
||||||
- 93.00
|
- 92.09
|
||||||
* - MRPC
|
* - MRPC
|
||||||
- F1/accuracy
|
- F1/accuracy
|
||||||
- 88.85/83.82
|
- 90.48/86.27
|
||||||
* - STS-B
|
* - STS-B
|
||||||
- Pearson/Spearman corr.
|
- Pearson/Spearman corr.
|
||||||
- 89.70/89.37
|
- 89.03/88.64
|
||||||
* - QQP
|
* - QQP
|
||||||
- accuracy/F1
|
- accuracy/F1
|
||||||
- 90.72/87.41
|
- 90.92/87.72
|
||||||
* - MNLI
|
* - MNLI
|
||||||
- matched acc./mismatched acc.
|
- matched acc./mismatched acc.
|
||||||
- 83.95/84.39
|
- 83.74/84.06
|
||||||
* - QNLI
|
* - QNLI
|
||||||
- accuracy
|
- accuracy
|
||||||
- 89.04
|
- 91.07
|
||||||
* - RTE
|
* - RTE
|
||||||
- accuracy
|
- accuracy
|
||||||
- 61.01
|
- 68.59
|
||||||
* - WNLI
|
* - WNLI
|
||||||
- accuracy
|
- accuracy
|
||||||
- 53.52
|
- 43.66
|
||||||
|
|
||||||
|
|
||||||
Some of these results are significantly different from the ones reported on the test set
|
Some of these results are significantly different from the ones reported on the test set
|
||||||
|
|||||||
@@ -62,6 +62,9 @@ Here is the full list of the currently provided pretrained models together with
|
|||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``gpt2-medium`` | | 24-layer, 1024-hidden, 16-heads, 345M parameters. |
|
| | ``gpt2-medium`` | | 24-layer, 1024-hidden, 16-heads, 345M parameters. |
|
||||||
| | | | OpenAI's Medium-sized GPT-2 English model |
|
| | | | OpenAI's Medium-sized GPT-2 English model |
|
||||||
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
|
| | ``gpt2-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters. |
|
||||||
|
| | | | OpenAI's Large-sized GPT-2 English model |
|
||||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| Transformer-XL | ``transfo-xl-wt103`` | | 18-layer, 1024-hidden, 16-heads, 257M parameters. |
|
| Transformer-XL | ``transfo-xl-wt103`` | | 18-layer, 1024-hidden, 16-heads, 257M parameters. |
|
||||||
| | | | English model trained on wikitext-103 |
|
| | | | English model trained on wikitext-103 |
|
||||||
@@ -72,16 +75,16 @@ Here is the full list of the currently provided pretrained models together with
|
|||||||
| | ``xlnet-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
|
| | ``xlnet-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
|
||||||
| | | | XLNet Large English model |
|
| | | | XLNet Large English model |
|
||||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| XLM | ``xlm-mlm-en-2048`` | | 12-layer, 1024-hidden, 8-heads |
|
| XLM | ``xlm-mlm-en-2048`` | | 12-layer, 2048-hidden, 16-heads |
|
||||||
| | | | XLM English model |
|
| | | | XLM English model |
|
||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``xlm-mlm-ende-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
| | ``xlm-mlm-ende-1024`` | | 6-layer, 1024-hidden, 8-heads |
|
||||||
| | | | XLM English-German Multi-language model |
|
| | | | XLM English-German Multi-language model |
|
||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``xlm-mlm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
| | ``xlm-mlm-enfr-1024`` | | 6-layer, 1024-hidden, 8-heads |
|
||||||
| | | | XLM English-French Multi-language model |
|
| | | | XLM English-French Multi-language model |
|
||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``xlm-mlm-enro-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
| | ``xlm-mlm-enro-1024`` | | 6-layer, 1024-hidden, 8-heads |
|
||||||
| | | | XLM English-Romanian Multi-language model |
|
| | | | XLM English-Romanian Multi-language model |
|
||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``xlm-mlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
| | ``xlm-mlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
||||||
@@ -93,7 +96,7 @@ Here is the full list of the currently provided pretrained models together with
|
|||||||
| | ``xlm-clm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
| | ``xlm-clm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
||||||
| | | | XLM English model trained with CLM (Causal Language Modeling) |
|
| | | | XLM English model trained with CLM (Causal Language Modeling) |
|
||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``xlm-clm-ende-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
| | ``xlm-clm-ende-1024`` | | 6-layer, 1024-hidden, 8-heads |
|
||||||
| | | | XLM English-German Multi-language model trained with CLM (Causal Language Modeling) |
|
| | | | XLM English-German Multi-language model trained with CLM (Causal Language Modeling) |
|
||||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| RoBERTa | ``roberta-base`` | | 12-layer, 768-hidden, 12-heads, 125M parameters |
|
| RoBERTa | ``roberta-base`` | | 12-layer, 768-hidden, 12-heads, 125M parameters |
|
||||||
|
|||||||
@@ -314,15 +314,16 @@ def main():
|
|||||||
mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
|
mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
|
||||||
pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
|
pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
scheduler.step() # Update learning rate schedule
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
scheduler.step() # Update learning rate schedule
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
# Save a trained model
|
# Save a trained model
|
||||||
if n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 :
|
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||||
logging.info("** ** * Saving fine-tuned model ** ** * ")
|
logging.info("** ** * Saving fine-tuned model ** ** * ")
|
||||||
model.save_pretrained(args.output_dir)
|
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||||
|
model_to_save.save_pretrained(args.output_dir)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -507,7 +507,7 @@ def main():
|
|||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||||
if not os.path.exists(args.output_dir) and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 ):
|
if not os.path.exists(args.output_dir) and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
os.makedirs(args.output_dir)
|
os.makedirs(args.output_dir)
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
@@ -602,15 +602,16 @@ def main():
|
|||||||
nb_tr_examples += input_ids.size(0)
|
nb_tr_examples += input_ids.size(0)
|
||||||
nb_tr_steps += 1
|
nb_tr_steps += 1
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
scheduler.step() # Update learning rate schedule
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
scheduler.step() # Update learning rate schedule
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
# Save a trained model
|
# Save a trained model
|
||||||
if args.do_train and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1):
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
logger.info("** ** * Saving fine - tuned model ** ** * ")
|
logger.info("** ** * Saving fine - tuned model ** ** * ")
|
||||||
model.save_pretrained(args.output_dir)
|
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||||
|
model_to_save.save_pretrained(args.output_dir)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -211,10 +211,12 @@ def prune_heads(args, model, eval_dataloader, head_mask):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
## Required parameters
|
||||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||||
parser.add_argument("--model_name", default=None, type=str, required=True,
|
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||||
help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(
|
||||||
|
ALL_MODELS))
|
||||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
parser.add_argument("--task_name", default=None, type=str, required=True,
|
||||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||||
@@ -222,9 +224,9 @@ def main():
|
|||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument("--config_name", default="", type=str,
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
help="Pretrained config name or path if not the same as model_name_or_path")
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
help="Pretrained tokenizer name or path if not the same as model_name_or_path")
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
parser.add_argument("--cache_dir", default="", type=str,
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||||
parser.add_argument("--data_subset", type=int, default=-1,
|
parser.add_argument("--data_subset", type=int, default=-1,
|
||||||
@@ -297,15 +299,15 @@ def main():
|
|||||||
|
|
||||||
args.model_type = ""
|
args.model_type = ""
|
||||||
for key in MODEL_CLASSES:
|
for key in MODEL_CLASSES:
|
||||||
if key in args.model_name.lower():
|
if key in args.model_name_or_path.lower():
|
||||||
args.model_type = key # take the first match in model types
|
args.model_type = key # take the first match in model types
|
||||||
break
|
break
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name,
|
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||||
num_labels=num_labels, finetuning_task=args.task_name,
|
num_labels=num_labels, finetuning_task=args.task_name,
|
||||||
output_attentions=True)
|
output_attentions=True)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name)
|
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path)
|
||||||
model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
|
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
|||||||
sep_token=tokenizer.sep_token,
|
sep_token=tokenizer.sep_token,
|
||||||
sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||||
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
||||||
pad_token=tokenizer.encoder[tokenizer.pad_token] if args.model_type in ['roberta'] else tokenizer.vocab[tokenizer.pad_token],
|
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||||
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
|
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
|
||||||
)
|
)
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
@@ -467,7 +467,7 @@ def main():
|
|||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = model_class.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -481,7 +481,7 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
# Save the trained model and the tokenizer
|
# Save the trained model and the tokenizer
|
||||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||||
os.makedirs(args.output_dir)
|
os.makedirs(args.output_dir)
|
||||||
@@ -498,7 +498,7 @@ def main():
|
|||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = model_class.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -422,12 +422,14 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
|
|||||||
tokens_b = tokenizer.tokenize(example.text_b)
|
tokens_b = tokenizer.tokenize(example.text_b)
|
||||||
# Modifies `tokens_a` and `tokens_b` in place so that the total
|
# Modifies `tokens_a` and `tokens_b` in place so that the total
|
||||||
# length is less than the specified length.
|
# length is less than the specified length.
|
||||||
# Account for [CLS], [SEP], [SEP] with "- 3"
|
# Account for [CLS], [SEP], [SEP] with "- 3". " -4" for RoBERTa.
|
||||||
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
|
special_tokens_count = 4 if sep_token_extra else 3
|
||||||
|
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
|
||||||
else:
|
else:
|
||||||
# Account for [CLS] and [SEP] with "- 2"
|
# Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
|
||||||
if len(tokens_a) > max_seq_length - 2:
|
special_tokens_count = 3 if sep_token_extra else 2
|
||||||
tokens_a = tokens_a[:(max_seq_length - 2)]
|
if len(tokens_a) > max_seq_length - special_tokens_count:
|
||||||
|
tokens_a = tokens_a[:(max_seq_length - special_tokens_count)]
|
||||||
|
|
||||||
# The convention in BERT is:
|
# The convention in BERT is:
|
||||||
# (a) For sequence pairs:
|
# (a) For sequence pairs:
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
|
|||||||
if gpt2_config_file == "":
|
if gpt2_config_file == "":
|
||||||
config = GPT2Config()
|
config = GPT2Config()
|
||||||
else:
|
else:
|
||||||
config = GPT2Config(gpt2_config_file)
|
config = GPT2Config.from_json_file(gpt2_config_file)
|
||||||
model = GPT2Model(config)
|
model = GPT2Model(config)
|
||||||
|
|
||||||
# Load weights from numpy
|
# Load weights from numpy
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
|||||||
if openai_config_file == "":
|
if openai_config_file == "":
|
||||||
config = OpenAIGPTConfig()
|
config = OpenAIGPTConfig()
|
||||||
else:
|
else:
|
||||||
config = OpenAIGPTConfig(openai_config_file)
|
config = OpenAIGPTConfig.from_json_file(openai_config_file)
|
||||||
model = OpenAIGPTModel(config)
|
model = OpenAIGPTModel(config)
|
||||||
|
|
||||||
# Load weights from numpy
|
# Load weights from numpy
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
|||||||
if transfo_xl_config_file == "":
|
if transfo_xl_config_file == "":
|
||||||
config = TransfoXLConfig()
|
config = TransfoXLConfig()
|
||||||
else:
|
else:
|
||||||
config = TransfoXLConfig(transfo_xl_config_file)
|
config = TransfoXLConfig.from_json_file(transfo_xl_config_file)
|
||||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||||
model = TransfoXLLMHeadModel(config)
|
model = TransfoXLLMHeadModel(config)
|
||||||
|
|
||||||
|
|||||||
@@ -17,8 +17,9 @@ from hashlib import sha256
|
|||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
import requests
|
from botocore.config import Config
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
|
import requests
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -93,12 +94,15 @@ def filename_to_url(filename, cache_dir=None):
|
|||||||
return url, etag
|
return url, etag
|
||||||
|
|
||||||
|
|
||||||
def cached_path(url_or_filename, cache_dir=None):
|
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
|
||||||
"""
|
"""
|
||||||
Given something that might be a URL (or might be a local path),
|
Given something that might be a URL (or might be a local path),
|
||||||
determine which. If it's a URL, download the file and cache it, and
|
determine which. If it's a URL, download the file and cache it, and
|
||||||
return the path to the cached file. If it's already a local path,
|
return the path to the cached file. If it's already a local path,
|
||||||
make sure the file exists and then return the path.
|
make sure the file exists and then return the path.
|
||||||
|
Args:
|
||||||
|
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
|
||||||
|
force_download: if True, re-dowload the file even if it's already cached in the cache dir.
|
||||||
"""
|
"""
|
||||||
if cache_dir is None:
|
if cache_dir is None:
|
||||||
cache_dir = PYTORCH_TRANSFORMERS_CACHE
|
cache_dir = PYTORCH_TRANSFORMERS_CACHE
|
||||||
@@ -111,7 +115,7 @@ def cached_path(url_or_filename, cache_dir=None):
|
|||||||
|
|
||||||
if parsed.scheme in ('http', 'https', 's3'):
|
if parsed.scheme in ('http', 'https', 's3'):
|
||||||
# URL, so get it from the cache (downloading if necessary)
|
# URL, so get it from the cache (downloading if necessary)
|
||||||
return get_from_cache(url_or_filename, cache_dir)
|
return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||||
elif os.path.exists(url_or_filename):
|
elif os.path.exists(url_or_filename):
|
||||||
# File, and it exists.
|
# File, and it exists.
|
||||||
return url_or_filename
|
return url_or_filename
|
||||||
@@ -156,24 +160,24 @@ def s3_request(func):
|
|||||||
|
|
||||||
|
|
||||||
@s3_request
|
@s3_request
|
||||||
def s3_etag(url):
|
def s3_etag(url, proxies=None):
|
||||||
"""Check ETag on S3 object."""
|
"""Check ETag on S3 object."""
|
||||||
s3_resource = boto3.resource("s3")
|
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
|
||||||
bucket_name, s3_path = split_s3_path(url)
|
bucket_name, s3_path = split_s3_path(url)
|
||||||
s3_object = s3_resource.Object(bucket_name, s3_path)
|
s3_object = s3_resource.Object(bucket_name, s3_path)
|
||||||
return s3_object.e_tag
|
return s3_object.e_tag
|
||||||
|
|
||||||
|
|
||||||
@s3_request
|
@s3_request
|
||||||
def s3_get(url, temp_file):
|
def s3_get(url, temp_file, proxies=None):
|
||||||
"""Pull a file directly from S3."""
|
"""Pull a file directly from S3."""
|
||||||
s3_resource = boto3.resource("s3")
|
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
|
||||||
bucket_name, s3_path = split_s3_path(url)
|
bucket_name, s3_path = split_s3_path(url)
|
||||||
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
||||||
|
|
||||||
|
|
||||||
def http_get(url, temp_file):
|
def http_get(url, temp_file, proxies=None):
|
||||||
req = requests.get(url, stream=True)
|
req = requests.get(url, stream=True, proxies=proxies)
|
||||||
content_length = req.headers.get('Content-Length')
|
content_length = req.headers.get('Content-Length')
|
||||||
total = int(content_length) if content_length is not None else None
|
total = int(content_length) if content_length is not None else None
|
||||||
progress = tqdm(unit="B", total=total)
|
progress = tqdm(unit="B", total=total)
|
||||||
@@ -184,7 +188,7 @@ def http_get(url, temp_file):
|
|||||||
progress.close()
|
progress.close()
|
||||||
|
|
||||||
|
|
||||||
def get_from_cache(url, cache_dir=None):
|
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None):
|
||||||
"""
|
"""
|
||||||
Given a URL, look for the corresponding dataset in the local cache.
|
Given a URL, look for the corresponding dataset in the local cache.
|
||||||
If it's not there, download it. Then return the path to the cached file.
|
If it's not there, download it. Then return the path to the cached file.
|
||||||
@@ -201,10 +205,10 @@ def get_from_cache(url, cache_dir=None):
|
|||||||
|
|
||||||
# Get eTag to add to filename, if it exists.
|
# Get eTag to add to filename, if it exists.
|
||||||
if url.startswith("s3://"):
|
if url.startswith("s3://"):
|
||||||
etag = s3_etag(url)
|
etag = s3_etag(url, proxies=proxies)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
response = requests.head(url, allow_redirects=True)
|
response = requests.head(url, allow_redirects=True, proxies=proxies)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
etag = None
|
etag = None
|
||||||
else:
|
else:
|
||||||
@@ -227,17 +231,17 @@ def get_from_cache(url, cache_dir=None):
|
|||||||
if matching_files:
|
if matching_files:
|
||||||
cache_path = os.path.join(cache_dir, matching_files[-1])
|
cache_path = os.path.join(cache_dir, matching_files[-1])
|
||||||
|
|
||||||
if not os.path.exists(cache_path):
|
if not os.path.exists(cache_path) or force_download:
|
||||||
# Download to temporary file, then copy to cache dir once finished.
|
# Download to temporary file, then copy to cache dir once finished.
|
||||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||||
with tempfile.NamedTemporaryFile() as temp_file:
|
with tempfile.NamedTemporaryFile() as temp_file:
|
||||||
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
|
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||||
|
|
||||||
# GET file object
|
# GET file object
|
||||||
if url.startswith("s3://"):
|
if url.startswith("s3://"):
|
||||||
s3_get(url, temp_file)
|
s3_get(url, temp_file, proxies=proxies)
|
||||||
else:
|
else:
|
||||||
http_get(url, temp_file)
|
http_get(url, temp_file, proxies=proxies)
|
||||||
|
|
||||||
# we are copying the file before closing it, so flush to avoid truncation
|
# we are copying the file before closing it, so flush to avoid truncation
|
||||||
temp_file.flush()
|
temp_file.flush()
|
||||||
|
|||||||
@@ -578,6 +578,8 @@ BERT_START_DOCSTRING = r""" The BERT model was proposed in
|
|||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||||
|
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
BERT_INPUTS_DOCSTRING = r"""
|
BERT_INPUTS_DOCSTRING = r"""
|
||||||
@@ -598,6 +600,9 @@ BERT_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
``token_type_ids: 0 0 0 0 0 0 0``
|
``token_type_ids: 0 0 0 0 0 0 0``
|
||||||
|
|
||||||
|
Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||||
|
the right rather than the left.
|
||||||
|
|
||||||
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
|
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
|
|||||||
@@ -38,9 +38,11 @@ from .modeling_bert import BertLayerNorm as LayerNorm
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
|
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
|
||||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
|
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
|
||||||
|
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"}
|
||||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
||||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
|
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
|
||||||
|
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"}
|
||||||
|
|
||||||
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||||
""" Load tf checkpoints in a pytorch model
|
""" Load tf checkpoints in a pytorch model
|
||||||
@@ -383,11 +385,15 @@ GPT2_START_DOCSTRING = r""" OpenAI GPT-2 model was proposed in
|
|||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (:class:`~pytorch_transformers.GPT2Config`): Model configuration class with all the parameters of the model.
|
config (:class:`~pytorch_transformers.GPT2Config`): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||||
|
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
GPT2_INPUTS_DOCSTRING = r""" Inputs:
|
GPT2_INPUTS_DOCSTRING = r""" Inputs:
|
||||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||||
|
the right rather than the left.
|
||||||
Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`.
|
Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`.
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
@@ -612,7 +618,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
@add_start_docstrings("""The GPT2 Model transformer with a language modeling and a multiple-choice classification
|
@add_start_docstrings("""The GPT2 Model transformer with a language modeling and a multiple-choice classification
|
||||||
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
|
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
|
||||||
The language modeling head has its weights tied to the input embeddings,
|
The language modeling head has its weights tied to the input embeddings,
|
||||||
the classification head takes as input the input of a specified classification token index in the intput sequence).
|
the classification head takes as input the input of a specified classification token index in the input sequence).
|
||||||
""", GPT2_START_DOCSTRING)
|
""", GPT2_START_DOCSTRING)
|
||||||
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||||
r""" Inputs:
|
r""" Inputs:
|
||||||
|
|||||||
@@ -397,11 +397,15 @@ OPENAI_GPT_START_DOCSTRING = r""" OpenAI GPT model was proposed in
|
|||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (:class:`~pytorch_transformers.OpenAIGPTConfig`): Model configuration class with all the parameters of the model.
|
config (:class:`~pytorch_transformers.OpenAIGPTConfig`): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||||
|
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
|
OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
|
||||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
GPT is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||||
|
the right rather than the left.
|
||||||
Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`.
|
Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`.
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
@@ -602,7 +606,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
@add_start_docstrings("""OpenAI GPT Model transformer with a language modeling and a multiple-choice classification
|
@add_start_docstrings("""OpenAI GPT Model transformer with a language modeling and a multiple-choice classification
|
||||||
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
|
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
|
||||||
The language modeling head has its weights tied to the input embeddings,
|
The language modeling head has its weights tied to the input embeddings,
|
||||||
the classification head takes as input the input of a specified classification token index in the intput sequence).
|
the classification head takes as input the input of a specified classification token index in the input sequence).
|
||||||
""", OPENAI_GPT_START_DOCSTRING)
|
""", OPENAI_GPT_START_DOCSTRING)
|
||||||
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||||
r""" Inputs:
|
r""" Inputs:
|
||||||
|
|||||||
@@ -90,7 +90,8 @@ ROBERTA_START_DOCSTRING = r""" The RoBERTa model was proposed in
|
|||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (:class:`~pytorch_transformers.RobertaConfig`): Model configuration class with all the parameters of the
|
config (:class:`~pytorch_transformers.RobertaConfig`): Model configuration class with all the parameters of the
|
||||||
model.
|
model. Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||||
|
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ROBERTA_INPUTS_DOCSTRING = r"""
|
ROBERTA_INPUTS_DOCSTRING = r"""
|
||||||
@@ -109,6 +110,10 @@ ROBERTA_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
Fully encoded sequences or sequence pairs can be obtained using the RobertaTokenizer.encode function with
|
Fully encoded sequences or sequence pairs can be obtained using the RobertaTokenizer.encode function with
|
||||||
the ``add_special_tokens`` parameter set to ``True``.
|
the ``add_special_tokens`` parameter set to ``True``.
|
||||||
|
|
||||||
|
RoBERTa is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||||
|
the right rather than the left.
|
||||||
|
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
|
|||||||
@@ -928,12 +928,16 @@ TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
|
|||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
|
config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||||
|
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
TRANSFO_XL_INPUTS_DOCSTRING = r"""
|
TRANSFO_XL_INPUTS_DOCSTRING = r"""
|
||||||
Inputs:
|
Inputs:
|
||||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
|
||||||
|
the right or on the left.
|
||||||
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
|
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
|
|||||||
@@ -71,6 +71,10 @@ class PretrainedConfig(object):
|
|||||||
r""" Base class for all configuration classes.
|
r""" Base class for all configuration classes.
|
||||||
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
|
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
|
||||||
|
It only affects the model's configuration.
|
||||||
|
|
||||||
Class attributes (overridden by derived classes):
|
Class attributes (overridden by derived classes):
|
||||||
- ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
|
- ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
|
||||||
|
|
||||||
@@ -121,6 +125,13 @@ class PretrainedConfig(object):
|
|||||||
- The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
|
- The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
|
||||||
- Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
|
- Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
|
||||||
|
|
||||||
|
force_download: (`optional`) boolean, default False:
|
||||||
|
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||||
|
|
||||||
|
proxies: (`optional`) dict, default None:
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||||
|
The proxies are used on each request.
|
||||||
|
|
||||||
return_unused_kwargs: (`optional`) bool:
|
return_unused_kwargs: (`optional`) bool:
|
||||||
|
|
||||||
- If False, then this function returns just the final configuration object.
|
- If False, then this function returns just the final configuration object.
|
||||||
@@ -142,6 +153,8 @@ class PretrainedConfig(object):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
cache_dir = kwargs.pop('cache_dir', None)
|
cache_dir = kwargs.pop('cache_dir', None)
|
||||||
|
force_download = kwargs.pop('force_download', False)
|
||||||
|
proxies = kwargs.pop('proxies', None)
|
||||||
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
||||||
|
|
||||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||||
@@ -152,7 +165,7 @@ class PretrainedConfig(object):
|
|||||||
config_file = pretrained_model_name_or_path
|
config_file = pretrained_model_name_or_path
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -396,6 +409,13 @@ class PreTrainedModel(nn.Module):
|
|||||||
Path to a directory in which a downloaded pre-trained model
|
Path to a directory in which a downloaded pre-trained model
|
||||||
configuration should be cached if the standard cache should not be used.
|
configuration should be cached if the standard cache should not be used.
|
||||||
|
|
||||||
|
force_download: (`optional`) boolean, default False:
|
||||||
|
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||||
|
|
||||||
|
proxies: (`optional`) dict, default None:
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||||
|
The proxies are used on each request.
|
||||||
|
|
||||||
output_loading_info: (`optional`) boolean:
|
output_loading_info: (`optional`) boolean:
|
||||||
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
||||||
|
|
||||||
@@ -420,6 +440,8 @@ class PreTrainedModel(nn.Module):
|
|||||||
state_dict = kwargs.pop('state_dict', None)
|
state_dict = kwargs.pop('state_dict', None)
|
||||||
cache_dir = kwargs.pop('cache_dir', None)
|
cache_dir = kwargs.pop('cache_dir', None)
|
||||||
from_tf = kwargs.pop('from_tf', False)
|
from_tf = kwargs.pop('from_tf', False)
|
||||||
|
force_download = kwargs.pop('force_download', False)
|
||||||
|
proxies = kwargs.pop('proxies', None)
|
||||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
@@ -427,6 +449,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
config, model_kwargs = cls.config_class.from_pretrained(
|
config, model_kwargs = cls.config_class.from_pretrained(
|
||||||
pretrained_model_name_or_path, *model_args,
|
pretrained_model_name_or_path, *model_args,
|
||||||
cache_dir=cache_dir, return_unused_kwargs=True,
|
cache_dir=cache_dir, return_unused_kwargs=True,
|
||||||
|
force_download=force_download,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -449,7 +472,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|||||||
@@ -416,12 +416,18 @@ XLM_START_DOCSTRING = r""" The XLM model was proposed in
|
|||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (:class:`~pytorch_transformers.XLMConfig`): Model configuration class with all the parameters of the model.
|
config (:class:`~pytorch_transformers.XLMConfig`): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||||
|
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
XLM_INPUTS_DOCSTRING = r"""
|
XLM_INPUTS_DOCSTRING = r"""
|
||||||
Inputs:
|
Inputs:
|
||||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
|
XLM is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||||
|
the right rather than the left.
|
||||||
|
|
||||||
Indices can be obtained using :class:`pytorch_transformers.XLMTokenizer`.
|
Indices can be obtained using :class:`pytorch_transformers.XLMTokenizer`.
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
@@ -434,8 +440,10 @@ XLM_INPUTS_DOCSTRING = r"""
|
|||||||
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
|
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
|
||||||
**langs**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**langs**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
A parallel sequence of tokens to be used to indicate the language of each token in the input.
|
A parallel sequence of tokens to be used to indicate the language of each token in the input.
|
||||||
Indices are selected in the pre-trained language vocabulary,
|
Indices are languages ids which can be obtained from the language names by using two conversion mappings
|
||||||
i.e. in the range ``[0, config.n_langs - 1[``.
|
provided in the configuration of the model (only provided for multilingual models).
|
||||||
|
More precisely, the `language name -> language id` mapping is in `model.config.lang2id` (dict str -> int) and
|
||||||
|
the `language id -> language name` mapping is `model.config.id2lang` (dict int -> str).
|
||||||
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Mask to avoid performing attention on padding token indices.
|
Mask to avoid performing attention on padding token indices.
|
||||||
Mask values selected in ``[0, 1]``:
|
Mask values selected in ``[0, 1]``:
|
||||||
|
|||||||
@@ -647,12 +647,16 @@ XLNET_START_DOCSTRING = r""" The XLNet model was proposed in
|
|||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
|
config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||||
|
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
XLNET_INPUTS_DOCSTRING = r"""
|
XLNET_INPUTS_DOCSTRING = r"""
|
||||||
Inputs:
|
Inputs:
|
||||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
XLNet is a model with relative position embeddings so you can either pad the inputs on
|
||||||
|
the right or on the left.
|
||||||
Indices can be obtained using :class:`pytorch_transformers.XLNetTokenizer`.
|
Indices can be obtained using :class:`pytorch_transformers.XLNetTokenizer`.
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
|
import uuid
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
import logging
|
import logging
|
||||||
@@ -527,7 +528,7 @@ class ConfigTester(object):
|
|||||||
|
|
||||||
def create_and_test_config_to_json_file(self):
|
def create_and_test_config_to_json_file(self):
|
||||||
config_first = self.config_class(**self.inputs_dict)
|
config_first = self.config_class(**self.inputs_dict)
|
||||||
json_file_path = "/tmp/config.json"
|
json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json")
|
||||||
config_first.to_json_file(json_file_path)
|
config_first.to_json_file(json_file_path)
|
||||||
config_second = self.config_class.from_json_file(json_file_path)
|
config_second = self.config_class.from_json_file(json_file_path)
|
||||||
os.remove(json_file_path)
|
os.remove(json_file_path)
|
||||||
|
|||||||
@@ -187,6 +187,8 @@ class BertTokenizer(PreTrainedTokenizer):
|
|||||||
index = 0
|
index = 0
|
||||||
if os.path.isdir(vocab_path):
|
if os.path.isdir(vocab_path):
|
||||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
|
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
|
else:
|
||||||
|
vocab_file = vocab_path
|
||||||
with open(vocab_file, "w", encoding="utf-8") as writer:
|
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||||
if index != token_index:
|
if index != token_index:
|
||||||
|
|||||||
@@ -45,17 +45,20 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
{
|
{
|
||||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
|
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
|
||||||
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
|
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
|
||||||
|
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
|
||||||
},
|
},
|
||||||
'merges_file':
|
'merges_file':
|
||||||
{
|
{
|
||||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
|
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
|
||||||
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
|
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
|
||||||
|
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
'gpt2': 1024,
|
'gpt2': 1024,
|
||||||
'gpt2-medium': 1024,
|
'gpt2-medium': 1024,
|
||||||
|
'gpt2-large': 1024,
|
||||||
}
|
}
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
|
|||||||
@@ -89,8 +89,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import ftfy
|
import ftfy
|
||||||
import spacy
|
from spacy.lang.en import English
|
||||||
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
|
_nlp = English()
|
||||||
|
self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
|
||||||
self.fix_text = ftfy.fix_text
|
self.fix_text = ftfy.fix_text
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
||||||
|
|||||||
@@ -193,6 +193,13 @@ class PreTrainedTokenizer(object):
|
|||||||
cache_dir: (`optional`) string:
|
cache_dir: (`optional`) string:
|
||||||
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
|
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
|
||||||
|
|
||||||
|
force_download: (`optional`) boolean, default False:
|
||||||
|
Force to (re-)download the vocabulary files and override the cached versions if they exists.
|
||||||
|
|
||||||
|
proxies: (`optional`) dict, default None:
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||||
|
The proxies are used on each request.
|
||||||
|
|
||||||
inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
|
inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
|
||||||
|
|
||||||
kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
|
kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
|
||||||
@@ -223,6 +230,8 @@ class PreTrainedTokenizer(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||||
cache_dir = kwargs.pop('cache_dir', None)
|
cache_dir = kwargs.pop('cache_dir', None)
|
||||||
|
force_download = kwargs.pop('force_download', False)
|
||||||
|
proxies = kwargs.pop('proxies', None)
|
||||||
|
|
||||||
s3_models = list(cls.max_model_input_sizes.keys())
|
s3_models = list(cls.max_model_input_sizes.keys())
|
||||||
vocab_files = {}
|
vocab_files = {}
|
||||||
@@ -283,7 +292,7 @@ class PreTrainedTokenizer(object):
|
|||||||
if file_path is None:
|
if file_path is None:
|
||||||
resolved_vocab_files[file_id] = None
|
resolved_vocab_files[file_id] = None
|
||||||
else:
|
else:
|
||||||
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
|
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
if pretrained_model_name_or_path in s3_models:
|
if pretrained_model_name_or_path in s3_models:
|
||||||
logger.error("Couldn't reach server to download vocabulary.")
|
logger.error("Couldn't reach server to download vocabulary.")
|
||||||
@@ -477,15 +486,45 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
Take care of added tokens.
|
Take care of added tokens.
|
||||||
"""
|
"""
|
||||||
|
def split_on_token(tok, text):
|
||||||
|
result = []
|
||||||
|
split_text = text.split(tok)
|
||||||
|
for i, sub_text in enumerate(split_text):
|
||||||
|
sub_text = sub_text.strip()
|
||||||
|
if i == 0 and not sub_text:
|
||||||
|
result += [tok]
|
||||||
|
elif i == len(split_text) - 1:
|
||||||
|
if sub_text:
|
||||||
|
result += [sub_text]
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if sub_text:
|
||||||
|
result += [sub_text]
|
||||||
|
result += [tok]
|
||||||
|
return result
|
||||||
|
|
||||||
def split_on_tokens(tok_list, text):
|
def split_on_tokens(tok_list, text):
|
||||||
if not text:
|
if not text:
|
||||||
return []
|
return []
|
||||||
if not tok_list:
|
if not tok_list:
|
||||||
return self._tokenize(text, **kwargs)
|
return self._tokenize(text, **kwargs)
|
||||||
tok = tok_list[0]
|
|
||||||
split_text = text.split(tok)
|
tokenized_text = []
|
||||||
return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \
|
text_list = [text]
|
||||||
for sub_text in split_text), [])[:-1]
|
for tok in tok_list:
|
||||||
|
tokenized_text = []
|
||||||
|
for sub_text in text_list:
|
||||||
|
if sub_text not in self.added_tokens_encoder \
|
||||||
|
and sub_text not in self.all_special_tokens:
|
||||||
|
tokenized_text += split_on_token(tok, sub_text)
|
||||||
|
else:
|
||||||
|
tokenized_text += [sub_text]
|
||||||
|
text_list = tokenized_text
|
||||||
|
|
||||||
|
return sum((self._tokenize(token, **kwargs) if token not \
|
||||||
|
in self.added_tokens_encoder and token not in self.all_special_tokens \
|
||||||
|
else [token] for token in tokenized_text), [])
|
||||||
|
|
||||||
added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
|
added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
|
||||||
tokenized_text = split_on_tokens(added_tokens, text)
|
tokenized_text = split_on_tokens(added_tokens, text)
|
||||||
|
|||||||
@@ -124,8 +124,9 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
**kwargs)
|
**kwargs)
|
||||||
try:
|
try:
|
||||||
import ftfy
|
import ftfy
|
||||||
import spacy
|
from spacy.lang.en import English
|
||||||
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
|
_nlp = English()
|
||||||
|
self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
|
||||||
self.fix_text = ftfy.fix_text
|
self.fix_text = ftfy.fix_text
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
||||||
|
|||||||
Reference in New Issue
Block a user