support the trocr small models (#14893)
* support the trocr small models * resolve conflict * Update docs/source/model_doc/trocr.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/model_doc/trocr.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/model_doc/trocr.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * fix unexpected indent in processing_trocr.py * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * update the docstring of processing_trocr * remove extra space Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
@@ -55,9 +55,9 @@ Tips:
|
|||||||
TrOCR's [`VisionEncoderDecoder`] model accepts images as input and makes use of
|
TrOCR's [`VisionEncoderDecoder`] model accepts images as input and makes use of
|
||||||
[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
|
[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
|
||||||
|
|
||||||
The [`ViTFeatureExtractor`] class is responsible for preprocessing the input image and
|
The [`ViTFeatureExtractor`]/[`DeiTFeatureExtractor`] class is responsible for preprocessing the input image and
|
||||||
[`RobertaTokenizer`] decodes the generated target tokens to the target string. The
|
[`RobertaTokenizer`]/[`XLMRobertaTokenizer`] decodes the generated target tokens to the target string. The
|
||||||
[`TrOCRProcessor`] wraps [`ViTFeatureExtractor`] and [`RobertaTokenizer`]
|
[`TrOCRProcessor`] wraps [`ViTFeatureExtractor`]/[`DeiTFeatureExtractor`] and [`RobertaTokenizer`]/[`XLMRobertaTokenizer`]
|
||||||
into a single instance to both extract the input features and decode the predicted token ids.
|
into a single instance to both extract the input features and decode the predicted token ids.
|
||||||
|
|
||||||
- Step-by-step Optical Character Recognition (OCR)
|
- Step-by-step Optical Character Recognition (OCR)
|
||||||
|
|||||||
@@ -20,22 +20,24 @@ from contextlib import contextmanager
|
|||||||
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
||||||
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer
|
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer
|
||||||
from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
||||||
|
from transformers.models.xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||||
|
from transformers.models.xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
|
||||||
|
|
||||||
from ..auto.feature_extraction_auto import AutoFeatureExtractor
|
from transformers import AutoTokenizer, AutoFeatureExtractor
|
||||||
|
|
||||||
|
|
||||||
class TrOCRProcessor:
|
class TrOCRProcessor:
|
||||||
r"""
|
r"""
|
||||||
Constructs a TrOCR processor which wraps a vision feature extractor and a TrOCR tokenizer into a single processor.
|
Constructs a TrOCR processor which wraps a vision feature extractor and a TrOCR tokenizer into a single processor.
|
||||||
|
|
||||||
[`TrOCRProcessor`] offers all the functionalities of [`AutoFeatureExtractor`] and [`RobertaTokenizer`]. See the
|
[`TrOCRProcessor`] offers all the functionalities of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] and [`RobertaTokenizer`/`XLMRobertaTokenizer`]. See the
|
||||||
[`~TrOCRProcessor.__call__`] and [`~TrOCRProcessor.decode`] for more information.
|
[`~TrOCRProcessor.__call__`] and [`~TrOCRProcessor.decode`] for more information.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
feature_extractor ([`AutoFeatureExtractor`]):
|
feature_extractor ([`ViTFeatureExtractor`/`DeiTFeatureExtractor`]):
|
||||||
An instance of [`AutoFeatureExtractor`]. The feature extractor is a required input.
|
An instance of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`]. The feature extractor is a required input.
|
||||||
tokenizer ([`RobertaTokenizer`]):
|
tokenizer ([`RobertaTokenizer`/`XLMRobertaTokenizer`]):
|
||||||
An instance of [`RobertaTokenizer`]. The tokenizer is a required input.
|
An instance of [`RobertaTokenizer`/`XLMRobertaTokenizer`]. The tokenizer is a required input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, feature_extractor, tokenizer):
|
def __init__(self, feature_extractor, tokenizer):
|
||||||
@@ -43,9 +45,9 @@ class TrOCRProcessor:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`feature_extractor` has to be of type {FeatureExtractionMixin.__class__}, but is {type(feature_extractor)}"
|
f"`feature_extractor` has to be of type {FeatureExtractionMixin.__class__}, but is {type(feature_extractor)}"
|
||||||
)
|
)
|
||||||
if not (isinstance(tokenizer, RobertaTokenizer) or (isinstance(tokenizer, RobertaTokenizerFast))):
|
if not isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, XLMRobertaTokenizer, XLMRobertaTokenizerFast)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__}, but is {type(tokenizer)}"
|
f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__} or {XLMRobertaTokenizer.__class__} or {XLMRobertaTokenizerFast.__class__}, but is {type(tokenizer)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.feature_extractor = feature_extractor
|
self.feature_extractor = feature_extractor
|
||||||
@@ -103,7 +105,7 @@ class TrOCRProcessor:
|
|||||||
[`PreTrainedTokenizer`]
|
[`PreTrainedTokenizer`]
|
||||||
"""
|
"""
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user