MarianMTModel.from_pretrained('Helsinki-NLP/opus-marian-en-de') (#3908)

Co-Authored-By: Stefan Schweter <stefan@schweter.it>
This commit is contained in:
Sam Shleifer
2020-04-28 18:22:37 -04:00
committed by GitHub
parent d714dfeaa8
commit 847e7f3379
12 changed files with 887 additions and 26 deletions

View File

@@ -1530,18 +1530,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return decoded
# Force one of token_ids to be generated by setting the score of every other token to -inf (i.e. probability 0).
def _force_token_ids_generation(self, scores, token_ids) -> None:
    """Restrict generation to ``token_ids`` by masking out all other tokens.

    ``scores`` is modified in place: every column whose index lies in
    ``range(self.config.vocab_size)`` and is not listed in ``token_ids`` is
    set to ``-inf``. Nothing is returned.

    Args:
        scores: rank-2 tensor; per the assertion message the expected shape
            is ``[batch_size, vocab_size]``.
        token_ids: a single int, or an iterable of int-like values, naming
            the token ids that must remain selectable.

    Raises:
        AssertionError: if ``scores`` is not of rank 2.
    """
    # Check the cheap precondition first so a bad call fails before the
    # O(vocab_size) mask is built.
    assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
    if isinstance(token_ids, int):
        token_ids = [token_ids]
    # A set makes each membership test O(1); scanning the list for every
    # vocabulary index would cost O(vocab_size * len(token_ids)).
    allowed_ids = {int(token_id) for token_id in token_ids}
    all_but_token_ids_mask = torch.tensor(
        [x for x in range(self.config.vocab_size) if x not in allowed_ids],
        dtype=torch.long,
        # The index tensor is placed on the same device as the model parameters.
        device=next(self.parameters()).device,
    )
    scores[:, all_but_token_ids_mask] = -float("inf")
@staticmethod
def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
    """Return a copy of ``past`` with each entry re-indexed along dimension 1.

    Every element of ``past`` is passed through ``index_select(1, beam_idx)``
    and the results are collected, in the same order, into a new tuple.
    NOTE(review): dimension 1 is presumably the batch/beam axis of the
    cached states -- confirm against the caller.
    """
    reordered_layers = []
    for cached_state in past:
        reordered_layers.append(cached_state.index_select(1, beam_idx))
    return tuple(reordered_layers)