MarianMTModel.from_pretrained('Helsinki-NLP/opus-marian-en-de') (#3908)
Co-Authored-By: Stefan Schweter <stefan@schweter.it>
This commit is contained in:
@@ -18,6 +18,7 @@ import math
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
@@ -125,7 +126,9 @@ class PretrainedBartModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
if isinstance(module, nn.Embedding):
|
||||
elif isinstance(module, SinusoidalPositionalEmbedding):
|
||||
pass
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
@@ -250,10 +253,16 @@ class BartEncoder(nn.Module):
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
|
||||
if config.static_position_embeddings:
|
||||
self.embed_positions = SinusoidalPositionalEmbedding(
|
||||
config.max_position_embeddings, embed_dim, self.padding_idx
|
||||
)
|
||||
else:
|
||||
self.embed_positions = LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, embed_dim, self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = LayerNorm(embed_dim)
|
||||
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
|
||||
# mbart has one extra layer_norm
|
||||
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
|
||||
|
||||
@@ -422,13 +431,18 @@ class BartDecoder(nn.Module):
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.embed_tokens = embed_tokens
|
||||
self.embed_positions = LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, config.d_model, self.padding_idx,
|
||||
)
|
||||
if config.static_position_embeddings:
|
||||
self.embed_positions = SinusoidalPositionalEmbedding(
|
||||
config.max_position_embeddings, config.d_model, config.pad_token_id
|
||||
)
|
||||
else:
|
||||
self.embed_positions = LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, config.d_model, self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
||||
) # type: List[DecoderLayer]
|
||||
self.layernorm_embedding = LayerNorm(config.d_model)
|
||||
self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
|
||||
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
|
||||
|
||||
def forward(
|
||||
@@ -470,7 +484,7 @@ class BartDecoder(nn.Module):
|
||||
if use_cache:
|
||||
input_ids = input_ids[:, -1:]
|
||||
positions = positions[:, -1:] # happens after we embed them
|
||||
assert input_ids.ne(self.padding_idx).any()
|
||||
# assert input_ids.ne(self.padding_idx).any()
|
||||
|
||||
x = self.embed_tokens(input_ids) * self.embed_scale
|
||||
x += positions
|
||||
@@ -859,6 +873,22 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
super().__init__(config)
|
||||
base_model = BartModel(config)
|
||||
self.model = base_model
|
||||
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
    """Resize the shared token embedding and keep ``final_logits_bias`` in step with it.

    Args:
        new_num_tokens: requested number of rows in the token embedding matrix.

    Returns:
        The embedding module returned by the parent class' ``resize_token_embeddings``.
    """
    # Read the current size first: the parent call below replaces the embedding matrix.
    previous_num_tokens = self.model.shared.num_embeddings
    resized_embeddings = super().resize_token_embeddings(new_num_tokens)
    # Re-point the model's shared embedding at the resized module.
    self.model.shared = resized_embeddings
    # The logits bias has one column per token, so it is resized alongside the embedding.
    self._resize_final_logits_bias(new_num_tokens, previous_num_tokens)
    return resized_embeddings
|
||||
|
||||
def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
|
||||
if new_num_tokens <= old_num_tokens:
|
||||
new_bias = self.final_logits_bias[:, :new_num_tokens]
|
||||
else:
|
||||
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens))
|
||||
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
||||
self.register_buffer("final_logits_bias", new_bias)
|
||||
|
||||
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
@@ -923,8 +953,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
lm_logits = F.linear(outputs[0], self.model.shared.weight)
|
||||
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
|
||||
lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
|
||||
outputs = (lm_logits,) + outputs[1:] # Add cache, hidden states and attention if they are here
|
||||
if lm_labels is not None:
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
# TODO(SS): do we need to ignore pad tokens in lm_labels?
|
||||
@@ -957,6 +987,18 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
self._force_token_ids_generation(scores, self.config.eos_token_id)
|
||||
return scores
|
||||
|
||||
def _force_token_ids_generation(self, scores, token_ids) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
all_but_token_ids_mask = torch.tensor(
|
||||
[x for x in range(self.config.vocab_size) if x not in token_ids],
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
|
||||
scores[:, all_but_token_ids_mask] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
((enc_out, enc_mask), decoder_cached_states) = past
|
||||
@@ -1061,3 +1103,39 @@ class BartForSequenceClassification(PretrainedBartModel):
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class SinusoidalPositionalEmbedding(nn.Embedding):
    """Fixed (non-trainable) sinusoidal positional embeddings.

    The table has ``num_positions`` rows. The first half of each row holds the sine
    features and the second half the cosine features (they are not interleaved).
    """

    def __init__(self, num_positions, embedding_dim, padding_idx=None):
        """Build the embedding table.

        Args:
            num_positions: number of positions (rows) in the table.
            embedding_dim: size of each position vector; must be even.
            padding_idx: accepted for signature compatibility; it is not forwarded
                to ``nn.Embedding`` and has no effect here.

        Raises:
            NotImplementedError: if ``embedding_dim`` is odd.
        """
        # Reject unsupported sizes before allocating the table.
        if embedding_dim % 2 != 0:
            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter):
        """Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
        The cos features are in the 2nd half of the vector. [dim // 2:]

        Fills ``out`` in place, marks it as not requiring gradients and returns it.
        """
        n_pos, dim = out.shape
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        # Must happen before the slice assignments: writing in place into a leaf tensor
        # that still requires grad raises a RuntimeError in current PyTorch.
        out.requires_grad = False
        half = dim // 2
        # The two halves only line up with the 0::2 / 1::2 column slices when dim is even.
        out[:, :half] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, half:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, input_ids, use_cache=False):
        """Input is expected to be of size [bsz x seqlen].

        Only the shape of ``input_ids`` is used. With ``use_cache`` a single position,
        ``seq_len - 1``, is looked up; otherwise positions ``0 .. seq_len - 1`` are.
        """
        _, seq_len = input_ids.shape[:2]
        if use_cache:
            positions = input_ids.data.new(1, 1).fill_(seq_len - 1)  # called before slicing
        else:
            # starts at 0, ends at seq_len - 1
            positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(positions)
|
||||
|
||||
Reference in New Issue
Block a user