TF support
This commit is contained in:
@@ -24,7 +24,8 @@ import os
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
|
from .file_utils import (TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME,
|
||||||
|
cached_path, hf_bucket_url, is_remote_url)
|
||||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -257,12 +258,14 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format(
|
raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format(
|
||||||
[WEIGHTS_NAME, TF2_WEIGHTS_NAME],
|
[WEIGHTS_NAME, TF2_WEIGHTS_NAME],
|
||||||
pretrained_model_name_or_path))
|
pretrained_model_name_or_path))
|
||||||
elif os.path.isfile(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||||
archive_file = pretrained_model_name_or_path + ".index"
|
archive_file = pretrained_model_name_or_path + ".index"
|
||||||
else:
|
else:
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME)
|
||||||
|
if from_pt:
|
||||||
|
raise EnvironmentError("Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name.")
|
||||||
|
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -372,7 +372,8 @@ class PreTrainedModel(nn.Module):
|
|||||||
archive_file = pretrained_model_name_or_path + ".index"
|
archive_file = pretrained_model_name_or_path + ".index"
|
||||||
else:
|
else:
|
||||||
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME)
|
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME)
|
||||||
# todo do we want to support TF checkpoints here?
|
if from_tf:
|
||||||
|
raise EnvironmentError("Loading a PyTorch model from a TF checkpoint is not supported when using a model identifier name.")
|
||||||
|
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user