Replace legacy tensor.Tensor with torch.tensor/torch.empty (#12027)

* Replace legacy torch.Tensor constructor with torch.{tensor, empty}

* Remove torch.Tensor in examples
This commit is contained in:
Mario Šaško
2021-06-08 05:58:38 -07:00
committed by GitHub
parent e33085d648
commit f5eec0d8e9
13 changed files with 26 additions and 22 deletions

View File

@@ -222,7 +222,7 @@ class GenerativeQAModule(BaseTransformer):
decoder_start_token_id = generator.config.decoder_start_token_id
decoder_input_ids = (
torch.cat(
[torch.Tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
[torch.tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
dim=1,
)
if target_ids.shape[0] < self.target_lens["train"]