Add OPTForQuestionAnswering (#19402)
* Add `OPTForQuestionAnswering` - added `OPTForQuestionAnswering` class based on `BloomForQuestionAnswering` - added `OPTForQuestionAnswering` in common tests - all common tests pass - make fixup done * added docstrings for OPTForQuestionAnswering * Fix docstrings for OPTForQuestionAnswering
This commit is contained in:
@@ -32,7 +32,13 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import GPT2Tokenizer, OPTForCausalLM, OPTForSequenceClassification, OPTModel
|
||||
from transformers import (
|
||||
GPT2Tokenizer,
|
||||
OPTForCausalLM,
|
||||
OPTForQuestionAnswering,
|
||||
OPTForSequenceClassification,
|
||||
OPTModel,
|
||||
)
|
||||
|
||||
|
||||
def prepare_opt_inputs_dict(
|
||||
@@ -178,7 +184,11 @@ class OPTModelTester:
|
||||
|
||||
@require_torch
|
||||
class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (OPTModel, OPTForCausalLM, OPTForSequenceClassification) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(OPTModel, OPTForCausalLM, OPTForSequenceClassification, OPTForQuestionAnswering)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
|
||||
is_encoder_decoder = False
|
||||
fx_compatible = True
|
||||
|
||||
Reference in New Issue
Block a user