AutoModelForTableQuestionAnswering (#9154)
* AutoModelForTableQuestionAnswering * Update src/transformers/models/auto/modeling_auto.py * Style
This commit is contained in:
@@ -114,6 +114,13 @@ AutoModelForQuestionAnswering
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
AutoModelForTableQuestionAnswering
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.AutoModelForTableQuestionAnswering
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
TFAutoModel
|
TFAutoModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -358,6 +358,7 @@ if is_torch_available():
|
|||||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
@@ -370,6 +371,7 @@ if is_torch_available():
|
|||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
|
AutoModelForTableQuestionAnswering,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -40,8 +40,8 @@ deps = {
|
|||||||
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
|
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
|
||||||
"sphinx": "sphinx==3.2.1",
|
"sphinx": "sphinx==3.2.1",
|
||||||
"starlette": "starlette",
|
"starlette": "starlette",
|
||||||
"tensorflow-cpu": "tensorflow-cpu>=2.0,<2.4",
|
"tensorflow-cpu": "tensorflow-cpu>=2.0",
|
||||||
"tensorflow": "tensorflow>=2.0,<2.4",
|
"tensorflow": "tensorflow>=2.0",
|
||||||
"timeout-decorator": "timeout-decorator",
|
"timeout-decorator": "timeout-decorator",
|
||||||
"tokenizers": "tokenizers==0.9.4",
|
"tokenizers": "tokenizers==0.9.4",
|
||||||
"torch": "torch>=1.0",
|
"torch": "torch>=1.0",
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ if is_torch_available():
|
|||||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
@@ -43,6 +44,7 @@ if is_torch_available():
|
|||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
|
AutoModelForTableQuestionAnswering,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -467,6 +467,12 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
|||||||
(FunnelConfig, FunnelForQuestionAnswering),
|
(FunnelConfig, FunnelForQuestionAnswering),
|
||||||
(LxmertConfig, LxmertForQuestionAnswering),
|
(LxmertConfig, LxmertForQuestionAnswering),
|
||||||
(MPNetConfig, MPNetForQuestionAnswering),
|
(MPNetConfig, MPNetForQuestionAnswering),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||||
|
[
|
||||||
|
# Model for Table Question Answering mapping
|
||||||
(TapasConfig, TapasForQuestionAnswering),
|
(TapasConfig, TapasForQuestionAnswering),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -1384,6 +1390,106 @@ class AutoModelForQuestionAnswering:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForTableQuestionAnswering:
|
||||||
|
r"""
|
||||||
|
This is a generic model class that will be instantiated as one of the model classes of the library---with a table
|
||||||
|
question answering head---when created with the when created with the
|
||||||
|
:meth:`~transformers.AutoModeForTableQuestionAnswering.from_pretrained` class method or the
|
||||||
|
:meth:`~transformers.AutoModelForTableQuestionAnswering.from_config` class method.
|
||||||
|
|
||||||
|
This class cannot be instantiated directly using ``__init__()`` (throws an error).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"AutoModelForQuestionAnswering is designed to be instantiated "
|
||||||
|
"using the `AutoModelForTableQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
|
"`AutoModelForTableQuestionAnswering.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@replace_list_option_in_docstrings(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, use_model_types=False)
|
||||||
|
def from_config(cls, config):
|
||||||
|
r"""
|
||||||
|
Instantiates one of the model classes of the library---with a table question answering head---from a
|
||||||
|
configuration.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Loading a model from its configuration file does **not** load the model weights. It only affects the
|
||||||
|
model's configuration. Use :meth:`~transformers.AutoModelForTableQuestionAnswering.from_pretrained` to load
|
||||||
|
the model weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (:class:`~transformers.PretrainedConfig`):
|
||||||
|
The model class to instantiate is selected based on the configuration class:
|
||||||
|
|
||||||
|
List options
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import AutoConfig, AutoModelForTableQuestionAnswering
|
||||||
|
>>> # Download configuration from huggingface.co and cache.
|
||||||
|
>>> config = AutoConfig.from_pretrained('google/tapas-base-finetuned-wtq')
|
||||||
|
>>> model = AutoModelForTableQuestionAnswering.from_config(config)
|
||||||
|
"""
|
||||||
|
if type(config) in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys():
|
||||||
|
return MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING[type(config)](config)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@replace_list_option_in_docstrings(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING)
|
||||||
|
@add_start_docstrings(
|
||||||
|
"Instantiate one of the model classes of the library---with a table question answering head---from a "
|
||||||
|
"pretrained model.",
|
||||||
|
AUTO_MODEL_PRETRAINED_DOCSTRING,
|
||||||
|
)
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
r"""
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import AutoConfig, AutoModelForTableQuestionAnswering
|
||||||
|
|
||||||
|
>>> # Download model and configuration from huggingface.co and cache.
|
||||||
|
>>> model = AutoModelForTableQuestionAnswering.from_pretrained('google/tapas-base-finetuned-wtq')
|
||||||
|
|
||||||
|
>>> # Update configuration during loading
|
||||||
|
>>> model = AutoModelForTableQuestionAnswering.from_pretrained('google/tapas-base-finetuned-wtq', output_attentions=True)
|
||||||
|
>>> model.config.output_attentions
|
||||||
|
True
|
||||||
|
|
||||||
|
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||||
|
>>> config = AutoConfig.from_json_file('./tf_model/tapas_tf_checkpoint.json')
|
||||||
|
>>> model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/tapas_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
"""
|
||||||
|
config = kwargs.pop("config", None)
|
||||||
|
if not isinstance(config, PretrainedConfig):
|
||||||
|
config, kwargs = AutoConfig.from_pretrained(
|
||||||
|
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if type(config) in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys():
|
||||||
|
return MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING[type(config)].from_pretrained(
|
||||||
|
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForTokenClassification:
|
class AutoModelForTokenClassification:
|
||||||
r"""
|
r"""
|
||||||
This is a generic model class that will be instantiated as one of the model classes of the library---with a token
|
This is a generic model class that will be instantiated as one of the model classes of the library---with a token
|
||||||
|
|||||||
@@ -303,6 +303,9 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
|
|||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
@@ -393,6 +396,15 @@ class AutoModelForSequenceClassification:
|
|||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForTableQuestionAnswering:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForTokenClassification:
|
class AutoModelForTokenClassification:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|||||||
@@ -17,7 +17,13 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
|
from transformers.testing_utils import (
|
||||||
|
DUMMY_UNKWOWN_IDENTIFIER,
|
||||||
|
SMALL_MODEL_IDENTIFIER,
|
||||||
|
require_scatter,
|
||||||
|
require_torch,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -30,6 +36,7 @@ if is_torch_available():
|
|||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
|
AutoModelForTableQuestionAnswering,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
@@ -44,6 +51,8 @@ if is_torch_available():
|
|||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
T5Config,
|
T5Config,
|
||||||
T5ForConditionalGeneration,
|
T5ForConditionalGeneration,
|
||||||
|
TapasConfig,
|
||||||
|
TapasForQuestionAnswering,
|
||||||
)
|
)
|
||||||
from transformers.models.auto.modeling_auto import (
|
from transformers.models.auto.modeling_auto import (
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
@@ -52,6 +61,7 @@ if is_torch_available():
|
|||||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
@@ -59,6 +69,7 @@ if is_torch_available():
|
|||||||
from transformers.models.bert.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.bert.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
from transformers.models.gpt2.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.gpt2.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
from transformers.models.t5.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.t5.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
from transformers.models.tapas.modeling_tapas import TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -168,6 +179,21 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
self.assertIsInstance(model, BertForQuestionAnswering)
|
self.assertIsInstance(model, BertForQuestionAnswering)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_scatter
|
||||||
|
def test_table_question_answering_model_from_pretrained(self):
|
||||||
|
for model_name in TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST[5:6]:
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, TapasConfig)
|
||||||
|
|
||||||
|
model = AutoModelForTableQuestionAnswering.from_pretrained(model_name)
|
||||||
|
model, loading_info = AutoModelForTableQuestionAnswering.from_pretrained(
|
||||||
|
model_name, output_loading_info=True
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, TapasForQuestionAnswering)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_token_classification_model_from_pretrained(self):
|
def test_token_classification_model_from_pretrained(self):
|
||||||
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
@@ -200,6 +226,7 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
MODEL_FOR_PRETRAINING_MAPPING,
|
MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
|
|||||||
@@ -25,9 +25,9 @@ from transformers import (
|
|||||||
MODEL_FOR_MASKED_LM_MAPPING,
|
MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
|
||||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
@@ -436,7 +436,7 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||||
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
|
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
|
||||||
elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
elif model_class in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.values():
|
||||||
inputs_dict["labels"] = torch.zeros(
|
inputs_dict["labels"] = torch.zeros(
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user