From 19666dcb3bee3e379f1458e295869957aac8590c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 3 Apr 2019 11:01:01 +0200 Subject: [PATCH] Should fix #438 --- pytorch_pretrained_bert/modeling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 938636142f..2736e34d7f 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -91,6 +91,8 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path): pointer = getattr(pointer, 'bias') elif l[0] == 'output_weights': pointer = getattr(pointer, 'weight') + elif l[0] == 'squad': + pointer = getattr(pointer, 'classifier') else: try: pointer = getattr(pointer, l[0])