Merge pull request #1735 from ondewo/tf-do-not-use-gpu-on-import
Do not use GPU when importing transformers
This commit is contained in:
@@ -51,7 +51,15 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
config_class = None
|
config_class = None
|
||||||
pretrained_model_archive_map = {}
|
pretrained_model_archive_map = {}
|
||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
|
|
||||||
|
@property
|
||||||
|
def dummy_inputs(self):
|
||||||
|
""" Dummy inputs to build the network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tf.Tensor with dummy inputs
|
||||||
|
"""
|
||||||
|
return tf.constant(DUMMY_INPUTS)
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
|
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user