From 9f75565ea8243ec685c3e5dd08a63e8f78af9d0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 8 Nov 2019 15:48:31 +0100 Subject: [PATCH] setup training --- requirements.txt | 2 -- transformers/generate/beam_search.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 060aba915d..4a3162adce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,5 +10,3 @@ regex sentencepiece # For XLM sacremoses -# For ROUGE -pyrouge diff --git a/transformers/generate/beam_search.py b/transformers/generate/beam_search.py index a18d20f31a..abe3186049 100644 --- a/transformers/generate/beam_search.py +++ b/transformers/generate/beam_search.py @@ -166,7 +166,7 @@ class BeamSearch(object): for step in range(self.max_length): decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id) - kwargs_decoder["attention_mask"] = build_mask(decoder_input) + kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id) outputs = self.model.decoder(decoder_input, **kwargs_decoder) next_token_scores = outputs[0][:, -1, :].squeeze(1)