From 629bd006bfd7e6210dcc95198be9b65614e4f051 Mon Sep 17 00:00:00 2001 From: Tim Rault Date: Fri, 2 Nov 2018 17:50:17 +0100 Subject: [PATCH 1/5] Convert optimization_test.py to PyTorch --- optimization_test_pytorch.py | 45 ++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 optimization_test_pytorch.py diff --git a/optimization_test_pytorch.py b/optimization_test_pytorch.py new file mode 100644 index 0000000000..5021467d1f --- /dev/null +++ b/optimization_test_pytorch.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import optimization_pytorch as optimization +import torch +import unittest + + +class OptimizationTest(unittest.TestCase): + + def assertListAlmostEqual(self, list1, list2, tol): + self.assertEqual(len(list1), len(list2)) + for a, b in zip(list1, list2): + self.assertAlmostEqual(a, b, delta=tol) + + def test_adam(self): + w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) + x = torch.tensor([0.4, 0.2, -0.5]) + criterion = torch.nn.MSELoss(reduction='elementwise_mean') + optimizer = optimization.BERTAdam(params={w}, lr=0.2, schedule='warmup_linear', warmup=0.1, t_total=100) + for _ in range(100): + # TODO Solve: reduction='elementwise_mean'=True not taken into account so division by x.size(0) is necessary + loss = criterion(x, w) / x.size(0) + loss.backward() + optimizer.step() + self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) + + +if __name__ == "__main__": + unittest.main() From 3ebf1a13c9db61c32b2d589a8823ef30485f0304 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 2 Nov 2018 17:49:35 -0400 Subject: [PATCH 2/5] Fix loss computation for indexes bigger than max_seq_length. --- modeling_pytorch.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/modeling_pytorch.py b/modeling_pytorch.py index 4a8514e3a0..b227dfeb91 100644 --- a/modeling_pytorch.py +++ b/modeling_pytorch.py @@ -485,9 +485,22 @@ class BertForQuestionAnswering(nn.Module): start_logits, end_logits = logits.split(1, dim=-1) if start_positions is not None and end_positions is not None: - loss_fct = CrossEntropyLoss() - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) + #loss_fct = CrossEntropyLoss() + #start_loss = loss_fct(start_logits, start_positions) + #end_loss = loss_fct(end_logits, end_positions) + batch_size, seq_length = input_ids.size() + + def compute_loss(logits, positions): + max_position = positions.max().item() + one_hot = torch.FloatTensor(batch_size, max(max_position, seq_length) +1).zero_() + one_hot = one_hot.scatter(1, positions, 1) + one_hot = one_hot[:, :seq_length] + log_probs = nn.functional.log_softmax(logits, dim = -1).view(batch_size, seq_length) + loss = -torch.mean(torch.sum(one_hot*log_probs), dim = -1) + return loss + + start_loss = compute_loss(start_logits, start_positions) + end_loss = compute_loss(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 return total_loss, (start_logits, end_logits) else: From e6a710f68473da14cf9ec50f9e748cfa01a927e5 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 2 Nov 2018 17:54:22 -0400 Subject: [PATCH 3/5] device --- modeling_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modeling_pytorch.py b/modeling_pytorch.py index b227dfeb91..07fb256104 100644 --- a/modeling_pytorch.py +++ b/modeling_pytorch.py @@ -492,7 +492,7 @@ class BertForQuestionAnswering(nn.Module): def compute_loss(logits, positions): max_position = positions.max().item() - one_hot = torch.FloatTensor(batch_size, max(max_position, seq_length) +1).zero_() + one_hot = torch.FloatTensor(batch_size, max(max_position, seq_length) +1, device=input_ids.device).zero_() one_hot = one_hot.scatter(1, positions, 1) one_hot = one_hot[:, :seq_length] log_probs = nn.functional.log_softmax(logits, dim = -1).view(batch_size, seq_length) From 25d5ca48e08cc3bea1bad8eb7b5f7dc19cc032ac Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 2 Nov 2018 17:57:46 -0400 Subject: [PATCH 4/5] Fix scatter LopngTensor --- modeling_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modeling_pytorch.py b/modeling_pytorch.py index 07fb256104..32ae35b74f 100644 --- a/modeling_pytorch.py +++ b/modeling_pytorch.py @@ -493,7 +493,7 @@ class BertForQuestionAnswering(nn.Module): def compute_loss(logits, positions): max_position = positions.max().item() one_hot = torch.FloatTensor(batch_size, max(max_position, seq_length) +1, device=input_ids.device).zero_() - one_hot = one_hot.scatter(1, positions, 1) + one_hot = one_hot.scatter(1, positions.cpu(), 1) # Second argument need to be LongTensor and not cuda.LongTensor one_hot = one_hot[:, :seq_length] log_probs = nn.functional.log_softmax(logits, dim = -1).view(batch_size, seq_length) loss = -torch.mean(torch.sum(one_hot*log_probs), dim = -1) From 72ab10399f6b2d6e338aca4ac0ebb2f16a0cfb2e Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 2 Nov 2018 18:06:21 -0400 Subject: [PATCH 5/5] Fix loss Please review @thomwolf but i think this is equivqlent (and it mimics the loss computation of the original loss) --- modeling_pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modeling_pytorch.py b/modeling_pytorch.py index 32ae35b74f..8d8c1aeed0 100644 --- a/modeling_pytorch.py +++ b/modeling_pytorch.py @@ -492,9 +492,9 @@ class BertForQuestionAnswering(nn.Module): def compute_loss(logits, positions): max_position = positions.max().item() - one_hot = torch.FloatTensor(batch_size, max(max_position, seq_length) +1, device=input_ids.device).zero_() + one_hot = torch.FloatTensor(batch_size, max(max_position, seq_length) +1).zero_() one_hot = one_hot.scatter(1, positions.cpu(), 1) # Second argument need to be LongTensor and not cuda.LongTensor - one_hot = one_hot[:, :seq_length] + one_hot = one_hot[:, :seq_length].to(input_ids.device) log_probs = nn.functional.log_softmax(logits, dim = -1).view(batch_size, seq_length) loss = -torch.mean(torch.sum(one_hot*log_probs), dim = -1) return loss