fix add_start_docstrings on python 2 (removed)
This commit is contained in:
@@ -15,17 +15,20 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch BERT model."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import copy
|
||||
from io import open
|
||||
|
||||
import six
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, functional as F
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .file_utils import cached_path
|
||||
|
||||
@@ -36,11 +39,18 @@ WEIGHTS_NAME = "pytorch_model.bin"
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
|
||||
def add_start_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
fn.__doc__ = ''.join(docstr) + fn.__doc__
|
||||
return fn
|
||||
return docstring_decorator
|
||||
if not six.PY2:
|
||||
def add_start_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
fn.__doc__ = ''.join(docstr) + fn.__doc__
|
||||
return fn
|
||||
return docstring_decorator
|
||||
else:
|
||||
# Not possible to update class docstrings on python2
|
||||
def add_start_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
return fn
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
class PretrainedConfig(object):
|
||||
|
||||
Reference in New Issue
Block a user