Adds LlamaForQuestionAnswering class in modeling_llama.py along with AutoModel Support (#28777)
* This is a test commit * testing commit * final commit with some changes * Removed copy statement * Fixed formatting issues * Fixed error added past_key_values in the forward method * Fixed a trailing whitespace. Damn the formatting rules are strict * Added the copy statement
This commit is contained in:
@@ -44,6 +44,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
CodeLlamaTokenizer,
|
||||
LlamaForCausalLM,
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaModel,
|
||||
LlamaTokenizer,
|
||||
@@ -278,7 +279,11 @@ class LlamaModelTester:
|
||||
|
||||
@require_torch
|
||||
class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForQuestionAnswering)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
@@ -286,6 +291,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
"text-classification": LlamaForSequenceClassification,
|
||||
"text-generation": LlamaForCausalLM,
|
||||
"zero-shot": LlamaForSequenceClassification,
|
||||
"question-answering": LlamaForQuestionAnswering,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
|
||||
Reference in New Issue
Block a user