This commit is contained in:
thomwolf
2019-04-03 11:01:01 +02:00
parent 1d8c232324
commit 19666dcb3b

View File

@@ -91,6 +91,8 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
pointer = getattr(pointer, 'bias') pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights': elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight') pointer = getattr(pointer, 'weight')
elif l[0] == 'squad':
pointer = getattr(pointer, 'classifier')
else: else:
try: try:
pointer = getattr(pointer, l[0]) pointer = getattr(pointer, l[0])