MarianMTModel.from_pretrained('Helsinki-NLP/opus-marian-en-de') (#3908)

Co-Authored-By: Stefan Schweter <stefan@schweter.it>
This commit is contained in:
Sam Shleifer
2020-04-28 18:22:37 -04:00
committed by GitHub
parent d714dfeaa8
commit 847e7f3379
12 changed files with 887 additions and 26 deletions

View File

@@ -18,6 +18,7 @@ import math
import random
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
@@ -125,7 +126,9 @@ class PretrainedBartModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
elif isinstance(module, SinusoidalPositionalEmbedding):
pass
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@@ -250,10 +253,16 @@ class BartEncoder(nn.Module):
self.max_source_positions = config.max_position_embeddings
self.embed_tokens = embed_tokens
self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
if config.static_position_embeddings:
self.embed_positions = SinusoidalPositionalEmbedding(
config.max_position_embeddings, embed_dim, self.padding_idx
)
else:
self.embed_positions = LearnedPositionalEmbedding(
config.max_position_embeddings, embed_dim, self.padding_idx,
)
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = LayerNorm(embed_dim)
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
# mbart has one extra layer_norm
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
@@ -422,13 +431,18 @@ class BartDecoder(nn.Module):
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = embed_tokens
self.embed_positions = LearnedPositionalEmbedding(
config.max_position_embeddings, config.d_model, self.padding_idx,
)
if config.static_position_embeddings:
self.embed_positions = SinusoidalPositionalEmbedding(
config.max_position_embeddings, config.d_model, config.pad_token_id
)
else:
self.embed_positions = LearnedPositionalEmbedding(
config.max_position_embeddings, config.d_model, self.padding_idx,
)
self.layers = nn.ModuleList(
[DecoderLayer(config) for _ in range(config.decoder_layers)]
) # type: List[DecoderLayer]
self.layernorm_embedding = LayerNorm(config.d_model)
self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
def forward(
@@ -470,7 +484,7 @@ class BartDecoder(nn.Module):
if use_cache:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
assert input_ids.ne(self.padding_idx).any()
# assert input_ids.ne(self.padding_idx).any()
x = self.embed_tokens(input_ids) * self.embed_scale
x += positions
@@ -859,6 +873,22 @@ class BartForConditionalGeneration(PretrainedBartModel):
super().__init__(config)
base_model = BartModel(config)
self.model = base_model
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
old_num_tokens = self.model.shared.num_embeddings
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self.model.shared = new_embeddings
self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
if new_num_tokens <= old_num_tokens:
new_bias = self.final_logits_bias[:, :new_num_tokens]
else:
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens))
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias)
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
def forward(
@@ -923,8 +953,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
decoder_cached_states=decoder_cached_states,
use_cache=use_cache,
)
lm_logits = F.linear(outputs[0], self.model.shared.weight)
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
outputs = (lm_logits,) + outputs[1:] # Add cache, hidden states and attention if they are here
if lm_labels is not None:
loss_fct = nn.CrossEntropyLoss()
# TODO(SS): do we need to ignore pad tokens in lm_labels?
@@ -957,6 +987,18 @@ class BartForConditionalGeneration(PretrainedBartModel):
self._force_token_ids_generation(scores, self.config.eos_token_id)
return scores
def _force_token_ids_generation(self, scores, token_ids) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""
if isinstance(token_ids, int):
token_ids = [token_ids]
all_but_token_ids_mask = torch.tensor(
[x for x in range(self.config.vocab_size) if x not in token_ids],
dtype=torch.long,
device=next(self.parameters()).device,
)
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
scores[:, all_but_token_ids_mask] = -float("inf")
@staticmethod
def _reorder_cache(past, beam_idx):
((enc_out, enc_mask), decoder_cached_states) = past
@@ -1061,3 +1103,39 @@ class BartForSequenceClassification(PretrainedBartModel):
outputs = (loss,) + outputs
return outputs
class SinusoidalPositionalEmbedding(nn.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions, embedding_dim, padding_idx=None):
super().__init__(num_positions, embedding_dim)
if embedding_dim % 2 != 0:
raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
self.weight = self._init_weight(self.weight)
@staticmethod
def _init_weight(out: nn.Parameter):
"""Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
The cos features are in the 2nd half of the vector. [dim // 2:]
"""
n_pos, dim = out.shape
position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) # This line breaks for odd n_pos
out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
out.detach_()
out.requires_grad = False
return out
@torch.no_grad()
def forward(self, input_ids, use_cache=False):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_ids.shape[:2]
if use_cache:
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
else:
# starts at 0, ends at 1-seq_len
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
return super().forward(positions)