Fix F821 flake8 warning (x47).
Ignore warnings related to Python 2, because it's going away soon.
This commit is contained in:
@@ -108,7 +108,7 @@ def read_swag_examples(input_file, is_training=True):
|
|||||||
lines = []
|
lines = []
|
||||||
for line in reader:
|
for line in reader:
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
line = list(unicode(cell, "utf-8") for cell in line)
|
line = list(unicode(cell, "utf-8") for cell in line) # noqa: F821
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
|
|
||||||
if is_training and lines[0][-1] != "label":
|
if is_training and lines[0][-1] != "label":
|
||||||
|
|||||||
@@ -225,7 +225,7 @@ def main():
|
|||||||
# Batch size == 1. to add more examples please use num_return_sequences > 1
|
# Batch size == 1. to add more examples please use num_return_sequences > 1
|
||||||
generated_sequence = output_sequences[0].tolist()
|
generated_sequence = output_sequences[0].tolist()
|
||||||
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
|
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
|
||||||
text = text[: t.find(args.stop_token) if args.stop_token else None]
|
text = text[: text.find(args.stop_token) if args.stop_token else None]
|
||||||
|
|
||||||
print(text)
|
print(text)
|
||||||
|
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ class SwagProcessor(DataProcessor):
|
|||||||
lines = []
|
lines = []
|
||||||
for line in reader:
|
for line in reader:
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
line = list(unicode(cell, "utf-8") for cell in line)
|
line = list(unicode(cell, "utf-8") for cell in line) # noqa: F821
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|||||||
@@ -68,6 +68,14 @@ TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
#
|
#
|
||||||
# See the conversion methods in modeling_tf_pytorch_utils.py for more details
|
# See the conversion methods in modeling_tf_pytorch_utils.py for more details
|
||||||
####################################################
|
####################################################
|
||||||
|
|
||||||
|
TFXxxAttention = tf.keras.layers.Layer
|
||||||
|
|
||||||
|
TFXxxIntermediate = tf.keras.layers.Layer
|
||||||
|
|
||||||
|
TFXxxOutput = tf.keras.layers.Layer
|
||||||
|
|
||||||
|
|
||||||
class TFXxxLayer(tf.keras.layers.Layer):
|
class TFXxxLayer(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super(TFXxxLayer, self).__init__(**kwargs)
|
super(TFXxxLayer, self).__init__(**kwargs)
|
||||||
@@ -316,6 +324,9 @@ class TFXxxModel(TFXxxPreTrainedModel):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
TFXxxMLMHead = tf.keras.layers.Layer
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""Xxx Model with a `language modeling` head on top. """, XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING
|
"""Xxx Model with a `language modeling` head on top. """, XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -135,6 +135,14 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
|||||||
#
|
#
|
||||||
# See the conversion methods in modeling_tf_pytorch_utils.py for more details
|
# See the conversion methods in modeling_tf_pytorch_utils.py for more details
|
||||||
####################################################
|
####################################################
|
||||||
|
|
||||||
|
XxxAttention = nn.Module
|
||||||
|
|
||||||
|
XxxIntermediate = nn.Module
|
||||||
|
|
||||||
|
XxxOutput = nn.Module
|
||||||
|
|
||||||
|
|
||||||
class XxxLayer(nn.Module):
|
class XxxLayer(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(XxxLayer, self).__init__()
|
super(XxxLayer, self).__init__()
|
||||||
@@ -160,6 +168,16 @@ class XxxLayer(nn.Module):
|
|||||||
# pointers for your model and the weights initialization
|
# pointers for your model and the weights initialization
|
||||||
# method if its not fully covered by PreTrainedModel's default method
|
# method if its not fully covered by PreTrainedModel's default method
|
||||||
####################################################
|
####################################################
|
||||||
|
|
||||||
|
XxxLayerNorm = torch.nn.LayerNorm
|
||||||
|
|
||||||
|
XxxEmbeddings = nn.Module
|
||||||
|
|
||||||
|
XxxEncoder = nn.Module
|
||||||
|
|
||||||
|
XxxPooler = nn.Module
|
||||||
|
|
||||||
|
|
||||||
class XxxPreTrainedModel(PreTrainedModel):
|
class XxxPreTrainedModel(PreTrainedModel):
|
||||||
""" An abstract class to handle weights initialization and
|
""" An abstract class to handle weights initialization and
|
||||||
a simple interface for dowloading and loading pretrained models.
|
a simple interface for dowloading and loading pretrained models.
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from getpass import getpass
|
from getpass import getpass
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
from transformers.hf_api import HfApi, HfFolder, HTTPError
|
from transformers.hf_api import HfApi, HfFolder, HTTPError
|
||||||
@@ -96,8 +97,7 @@ class LogoutCommand(BaseUserCommand):
|
|||||||
|
|
||||||
|
|
||||||
class ListObjsCommand(BaseUserCommand):
|
class ListObjsCommand(BaseUserCommand):
|
||||||
def tabulate(self, rows, headers):
|
def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
|
||||||
# type: (List[List[Union[str, int]]], List[str]) -> str
|
|
||||||
"""
|
"""
|
||||||
Inspired by:
|
Inspired by:
|
||||||
stackoverflow.com/a/8356620/593036
|
stackoverflow.com/a/8356620/593036
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ class DataProcessor(object):
|
|||||||
lines = []
|
lines = []
|
||||||
for line in reader:
|
for line in reader:
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
line = list(unicode(cell, "utf-8") for cell in line)
|
line = list(unicode(cell, "utf-8") for cell in line) # noqa: F821
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|||||||
@@ -419,7 +419,7 @@ def get_from_cache(
|
|||||||
with open(meta_path, "w") as meta_file:
|
with open(meta_path, "w") as meta_file:
|
||||||
output_string = json.dumps(meta)
|
output_string = json.dumps(meta)
|
||||||
if sys.version_info[0] == 2 and isinstance(output_string, str):
|
if sys.version_info[0] == 2 and isinstance(output_string, str):
|
||||||
output_string = unicode(output_string, "utf-8") # The beauty of python 2
|
output_string = unicode(output_string, "utf-8") # noqa: F821
|
||||||
meta_file.write(output_string)
|
meta_file.write(output_string)
|
||||||
|
|
||||||
return cache_path
|
return cache_path
|
||||||
|
|||||||
@@ -14,8 +14,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
from os.path import expanduser
|
from os.path import expanduser
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import six
|
import six
|
||||||
@@ -93,7 +95,7 @@ class HfApi:
|
|||||||
return d["user"]
|
return d["user"]
|
||||||
|
|
||||||
def logout(self, token):
|
def logout(self, token):
|
||||||
# type: (...) -> void
|
# type: (...) -> None
|
||||||
"""
|
"""
|
||||||
Call HF API to log out.
|
Call HF API to log out.
|
||||||
"""
|
"""
|
||||||
@@ -135,8 +137,7 @@ class HfApi:
|
|||||||
pf.close()
|
pf.close()
|
||||||
return urls.access
|
return urls.access
|
||||||
|
|
||||||
def list_objs(self, token):
|
def list_objs(self, token) -> List[S3Obj]:
|
||||||
# type: (...) -> List[S3Obj]
|
|
||||||
"""
|
"""
|
||||||
Call HF API to list all stored files for user.
|
Call HF API to list all stored files for user.
|
||||||
"""
|
"""
|
||||||
@@ -156,9 +157,7 @@ class TqdmProgressFileReader:
|
|||||||
for implementation details.
|
for implementation details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, f: io.BufferedReader):
|
||||||
self, f # type: io.BufferedReader
|
|
||||||
):
|
|
||||||
self.f = f
|
self.f = f
|
||||||
self.total_size = os.fstat(f.fileno()).st_size # type: int
|
self.total_size = os.fstat(f.fileno()).st_size # type: int
|
||||||
self.pbar = tqdm(total=self.total_size, leave=False)
|
self.pbar = tqdm(total=self.total_size, leave=False)
|
||||||
|
|||||||
@@ -339,7 +339,9 @@ class BertIntermediate(nn.Module):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertIntermediate, self).__init__()
|
super(BertIntermediate, self).__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
if isinstance(config.hidden_act, str) or (
|
||||||
|
sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) # noqa: F821
|
||||||
|
):
|
||||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
@@ -459,7 +461,9 @@ class BertPredictionHeadTransform(nn.Module):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertPredictionHeadTransform, self).__init__()
|
super(BertPredictionHeadTransform, self).__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
if isinstance(config.hidden_act, str) or (
|
||||||
|
sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) # noqa: F821
|
||||||
|
):
|
||||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||||
else:
|
else:
|
||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
|
|||||||
@@ -311,7 +311,9 @@ class TFAlbertLayer(tf.keras.layers.Layer):
|
|||||||
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn"
|
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn"
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
if isinstance(config.hidden_act, str) or (
|
||||||
|
sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) # noqa: F821
|
||||||
|
):
|
||||||
self.activation = ACT2FN[config.hidden_act]
|
self.activation = ACT2FN[config.hidden_act]
|
||||||
else:
|
else:
|
||||||
self.activation = config.hidden_act
|
self.activation = config.hidden_act
|
||||||
@@ -452,7 +454,9 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
|
|||||||
self.dense = tf.keras.layers.Dense(
|
self.dense = tf.keras.layers.Dense(
|
||||||
config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||||
)
|
)
|
||||||
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
if isinstance(config.hidden_act, str) or (
|
||||||
|
sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) # noqa: F821
|
||||||
|
):
|
||||||
self.activation = ACT2FN[config.hidden_act]
|
self.activation = ACT2FN[config.hidden_act]
|
||||||
else:
|
else:
|
||||||
self.activation = config.hidden_act
|
self.activation = config.hidden_act
|
||||||
|
|||||||
@@ -690,9 +690,9 @@ class TFAutoModelForQuestionAnswering(object):
|
|||||||
elif isinstance(config, BertConfig):
|
elif isinstance(config, BertConfig):
|
||||||
return TFBertForQuestionAnswering(config)
|
return TFBertForQuestionAnswering(config)
|
||||||
elif isinstance(config, XLNetConfig):
|
elif isinstance(config, XLNetConfig):
|
||||||
return TFXLNetForQuestionAnswering(config)
|
raise NotImplementedError("TFXLNetForQuestionAnswering isn't implemented")
|
||||||
elif isinstance(config, XLMConfig):
|
elif isinstance(config, XLMConfig):
|
||||||
return TFXLMForQuestionAnswering(config)
|
raise NotImplementedError("TFXLMForQuestionAnswering isn't implemented")
|
||||||
raise ValueError("Unrecognized configuration class {}".format(config))
|
raise ValueError("Unrecognized configuration class {}".format(config))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -315,7 +315,9 @@ class TFBertIntermediate(tf.keras.layers.Layer):
|
|||||||
self.dense = tf.keras.layers.Dense(
|
self.dense = tf.keras.layers.Dense(
|
||||||
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||||
)
|
)
|
||||||
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
if isinstance(config.hidden_act, str) or (
|
||||||
|
sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) # noqa: F821
|
||||||
|
):
|
||||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
@@ -420,7 +422,9 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
|
|||||||
self.dense = tf.keras.layers.Dense(
|
self.dense = tf.keras.layers.Dense(
|
||||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||||
)
|
)
|
||||||
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
if isinstance(config.hidden_act, str) or (
|
||||||
|
sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) # noqa: F821
|
||||||
|
):
|
||||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||||
else:
|
else:
|
||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ class TFXLNetFeedForward(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||||
if isinstance(config.ff_activation, str) or (
|
if isinstance(config.ff_activation, str) or (
|
||||||
sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)
|
sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode) # noqa: F821
|
||||||
):
|
):
|
||||||
self.activation_function = ACT2FN[config.ff_activation]
|
self.activation_function = ACT2FN[config.ff_activation]
|
||||||
else:
|
else:
|
||||||
@@ -483,7 +483,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
if dtype is not None and dtype != tf.float32:
|
if dtype is not None and dtype != tf.float32:
|
||||||
fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
|
fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
|
||||||
if self.clamp_len > 0:
|
if self.clamp_len > 0:
|
||||||
fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len)
|
fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
|
||||||
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
|
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
|
||||||
|
|
||||||
return pos_emb
|
return pos_emb
|
||||||
|
|||||||
@@ -431,7 +431,7 @@ class XLNetFeedForward(nn.Module):
|
|||||||
self.layer_2 = nn.Linear(config.d_inner, config.d_model)
|
self.layer_2 = nn.Linear(config.d_inner, config.d_model)
|
||||||
self.dropout = nn.Dropout(config.dropout)
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
if isinstance(config.ff_activation, str) or (
|
if isinstance(config.ff_activation, str) or (
|
||||||
sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)
|
sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode) # noqa: F821
|
||||||
):
|
):
|
||||||
self.activation_function = ACT2FN[config.ff_activation]
|
self.activation_function = ACT2FN[config.ff_activation]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class TokenizerUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
for special_tok in tokenizer.all_special_tokens:
|
for special_tok in tokenizer.all_special_tokens:
|
||||||
if six.PY2:
|
if six.PY2:
|
||||||
self.assertIsInstance(special_tok, unicode)
|
self.assertIsInstance(special_tok, unicode) # noqa: F821
|
||||||
else:
|
else:
|
||||||
self.assertIsInstance(special_tok, str)
|
self.assertIsInstance(special_tok, str)
|
||||||
special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
|
special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
|
|||||||
"""
|
"""
|
||||||
text = self.preprocess_text(text)
|
text = self.preprocess_text(text)
|
||||||
# note(zhiliny): in some systems, sentencepiece only accepts str for py2
|
# note(zhiliny): in some systems, sentencepiece only accepts str for py2
|
||||||
if six.PY2 and isinstance(text, unicode):
|
if six.PY2 and isinstance(text, unicode): # noqa: F821
|
||||||
text = text.encode("utf-8")
|
text = text.encode("utf-8")
|
||||||
|
|
||||||
if not sample:
|
if not sample:
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ def bytes_to_unicode():
|
|||||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||||
"""
|
"""
|
||||||
_chr = unichr if sys.version_info[0] == 2 else chr
|
_chr = unichr if sys.version_info[0] == 2 else chr # noqa: F821
|
||||||
bs = (
|
bs = (
|
||||||
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,10 +36,10 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
# import cPickle as pickle
|
import cPickle as pickle
|
||||||
# else:
|
else:
|
||||||
# import pickle
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -252,10 +252,10 @@ class PreTrainedTokenizer(object):
|
|||||||
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
|
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
|
||||||
if key == "additional_special_tokens":
|
if key == "additional_special_tokens":
|
||||||
assert isinstance(value, (list, tuple)) and all(
|
assert isinstance(value, (list, tuple)) and all(
|
||||||
isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value
|
isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value # noqa: F821
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
|
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode)) # noqa: F821
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -567,7 +567,7 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
to_add_tokens = []
|
to_add_tokens = []
|
||||||
for token in new_tokens:
|
for token in new_tokens:
|
||||||
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
|
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode)) # noqa: F821
|
||||||
if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
|
if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
|
||||||
token = token.lower()
|
token = token.lower()
|
||||||
if (
|
if (
|
||||||
@@ -650,11 +650,11 @@ class PreTrainedTokenizer(object):
|
|||||||
assert key in self.SPECIAL_TOKENS_ATTRIBUTES
|
assert key in self.SPECIAL_TOKENS_ATTRIBUTES
|
||||||
if key == "additional_special_tokens":
|
if key == "additional_special_tokens":
|
||||||
assert isinstance(value, (list, tuple)) and all(
|
assert isinstance(value, (list, tuple)) and all(
|
||||||
isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value
|
isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value # noqa: F821
|
||||||
)
|
)
|
||||||
added_tokens += self.add_tokens(value)
|
added_tokens += self.add_tokens(value)
|
||||||
else:
|
else:
|
||||||
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
|
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode)) # noqa: F821
|
||||||
added_tokens += self.add_tokens([value])
|
added_tokens += self.add_tokens([value])
|
||||||
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
|
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
@@ -746,7 +746,7 @@ class PreTrainedTokenizer(object):
|
|||||||
if tokens is None:
|
if tokens is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
|
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)): # noqa: F821
|
||||||
return self._convert_token_to_id_with_added_voc(tokens)
|
return self._convert_token_to_id_with_added_voc(tokens)
|
||||||
|
|
||||||
ids = []
|
ids = []
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
"""
|
"""
|
||||||
text = self.preprocess_text(text)
|
text = self.preprocess_text(text)
|
||||||
# note(zhiliny): in some systems, sentencepiece only accepts str for py2
|
# note(zhiliny): in some systems, sentencepiece only accepts str for py2
|
||||||
if six.PY2 and isinstance(text, unicode):
|
if six.PY2 and isinstance(text, unicode): # noqa: F821
|
||||||
text = text.encode("utf-8")
|
text = text.encode("utf-8")
|
||||||
|
|
||||||
if not sample:
|
if not sample:
|
||||||
|
|||||||
Reference in New Issue
Block a user