TF support
This commit is contained in:
@@ -24,7 +24,8 @@ import os
|
||||
import tensorflow as tf
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
|
||||
from .file_utils import (TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME,
|
||||
cached_path, hf_bucket_url, is_remote_url)
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -257,12 +258,14 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format(
|
||||
[WEIGHTS_NAME, TF2_WEIGHTS_NAME],
|
||||
pretrained_model_name_or_path))
|
||||
elif os.path.isfile(pretrained_model_name_or_path):
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
archive_file = pretrained_model_name_or_path
|
||||
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME)
|
||||
if from_pt:
|
||||
raise EnvironmentError("Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name.")
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
|
||||
@@ -372,7 +372,8 @@ class PreTrainedModel(nn.Module):
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME)
|
||||
# todo do we want to support TF checkpoints here?
|
||||
if from_tf:
|
||||
raise EnvironmentError("Loading a PyTorch model from a TF checkpoint is not supported when using a model identifier name.")
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user