From ac7d5f67a2d7e0a086c37ea9107d95901cc678cb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 11 May 2020 16:38:07 +0200 Subject: [PATCH] [Reformer] Add Enwiki8 Reformer Model - Adapt convert script (#4282) * adapt convert script * update convert script * finish * fix marian pretrained docs --- docs/source/pretrained_models.rst | 9 ++++++--- src/transformers/configuration_reformer.py | 3 ++- ...convert_reformer_trax_checkpoint_to_pytorch.py | 15 ++++++--------- src/transformers/modeling_reformer.py | 3 ++- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index 64ba6167e4..0d82b68127 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -296,9 +296,12 @@ For a list that includes community-uploaded models, refer to `https://huggingfac | | ``DialoGPT-large`` | | 36-layer, 1280-hidden, 20-heads, 774M parameters | | | | | Trained on English text: 147M conversation-like exchanges extracted from Reddit. | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| Reformer | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters | -| | | | Trained on English text: Crime and Punishment novel by Fyodor Dostoyevsky | +| Reformer | ``reformer-enwik8`` | | 12-layer, 1024-hidden, 8-heads, 149M parameters | +| | | | Trained on English Wikipedia data - enwik8. | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters | +| | | | Trained on English text: Crime and Punishment novel by Fyodor Dostoyevsky. | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | MarianMT | ``Helsinki-NLP/opus-mt-{src}-{tgt}`` | | 12-layer, 512-hidden, 8-heads, ~74M parameter Machine translation models. Parameter counts vary depending on vocab size. | -| | | | (see `model list `_ | +| | | | (see `model list `_) | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index 4e7227bfa0..572fa58fac 100644 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -24,7 +24,8 @@ from .configuration_utils import PretrainedConfig logger = logging.getLogger(__name__) REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/config.json" + "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/config.json", + "google/reformer-enwik8": "https://cdn.huggingface.co/google/reformer-enwik8/config.json", } diff --git a/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py index 3d88e3d78d..5e6dee7c08 100755 --- a/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py +++ b/src/transformers/convert_reformer_trax_checkpoint_to_pytorch.py @@ -93,7 +93,7 @@ def set_block_weights_in_torch(weights, torch_block, hidden_size): set_layer_weights_in_torch_local(attn_weights, torch_block.attention, hidden_size) # intermediate weighs - intermediate_weights = weights[2][0][2][2] + intermediate_weights = weights[2][0][1][2] # Chunked Feed Forward if len(intermediate_weights) == 4: @@ -145,19 +145,16 @@ def set_model_weights_in_torch(weights, torch_model, hidden_size): position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights)) trax_layer_weights = weights[5] - assert len(torch_model_reformer.encoder.layers) * 4 + 1 == len( + assert len(torch_model_reformer.encoder.layers) * 4 == len( trax_layer_weights ), "HF and trax model do not have the same number of layers" for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers): block_weights = trax_layer_weights[4 * layer_idx : 4 * (layer_idx + 1)] set_block_weights_in_torch(block_weights, layer, hidden_size) - # output weights - out_weights = weights[6] - # output layer norm - layer_norm_out_weight = np.asarray(out_weights[0][0]) - layer_norm_out_bias = np.asarray(out_weights[0][1]) + layer_norm_out_weight = np.asarray(weights[7][0]) + layer_norm_out_bias = np.asarray(weights[7][1]) set_param( torch_model_reformer.encoder.layer_norm, torch.tensor(layer_norm_out_weight), @@ -165,8 +162,8 @@ def set_model_weights_in_torch(weights, torch_model, hidden_size): ) # output embeddings - output_embed_weights = np.asarray(out_weights[2][0]) - output_embed_bias = np.asarray(out_weights[2][1]) + output_embed_weights = np.asarray(weights[9][0]) + output_embed_bias = np.asarray(weights[9][1]) set_param( torch_model.lm_head.decoder, torch.tensor(output_embed_weights).transpose(0, 1).contiguous(), diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index dc671b0ab5..39a6d7d951 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -36,7 +36,8 @@ from .modeling_utils import PreTrainedModel, apply_chunking_to_forward logger = logging.getLogger(__name__) REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = { - "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/pytorch_model.bin" + "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/pytorch_model.bin", + "google/reformer-enwik8": "https://cdn.huggingface.co/google/reformer-enwik8/pytorch_model.bin", }