From b2c477fc6d2fd7e5295f9edc3743eec32d2df924 Mon Sep 17 00:00:00 2001 From: Minghao Li Date: Mon, 10 Jan 2022 22:28:03 +0800 Subject: [PATCH] support the trocr small models (#14893) * support the trocr small models * resolve conflict * Update docs/source/model_doc/trocr.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/model_doc/trocr.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/model_doc/trocr.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * fix unexpected indent in processing_trocr.py * Update src/transformers/models/trocr/processing_trocr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * update the docstring of processing_trocr * remove extra space Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> --- docs/source/model_doc/trocr.mdx | 6 +++--- .../models/trocr/processing_trocr.py | 20 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/docs/source/model_doc/trocr.mdx b/docs/source/model_doc/trocr.mdx index b8cdbaeb6f..494895c1a8 100644 --- a/docs/source/model_doc/trocr.mdx +++ b/docs/source/model_doc/trocr.mdx @@ -55,9 +55,9 @@ Tips: TrOCR's [`VisionEncoderDecoder`] model accepts images as input and makes use of [`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image. -The [`ViTFeatureExtractor`] class is responsible for preprocessing the input image and -[`RobertaTokenizer`] decodes the generated target tokens to the target string. The -[`TrOCRProcessor`] wraps [`ViTFeatureExtractor`] and [`RobertaTokenizer`] +The [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] class is responsible for preprocessing the input image and +[`RobertaTokenizer`/`XLMRobertaTokenizer`] decodes the generated target tokens to the target string. The +[`TrOCRProcessor`] wraps [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] and [`RobertaTokenizer`/`XLMRobertaTokenizer`] into a single instance to both extract the input features and decode the predicted token ids. - Step-by-step Optical Character Recognition (OCR) diff --git a/src/transformers/models/trocr/processing_trocr.py b/src/transformers/models/trocr/processing_trocr.py index 3166cbae20..6022e54d22 100644 --- a/src/transformers/models/trocr/processing_trocr.py +++ b/src/transformers/models/trocr/processing_trocr.py @@ -20,22 +20,24 @@ from contextlib import contextmanager from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.models.roberta.tokenization_roberta import RobertaTokenizer from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast +from transformers.models.xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer +from transformers.models.xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast -from ..auto.feature_extraction_auto import AutoFeatureExtractor +from transformers import AutoTokenizer, AutoFeatureExtractor class TrOCRProcessor: r""" Constructs a TrOCR processor which wraps a vision feature extractor and a TrOCR tokenizer into a single processor. - [`TrOCRProcessor`] offers all the functionalities of [`AutoFeatureExtractor`] and [`RobertaTokenizer`]. See the + [`TrOCRProcessor`] offers all the functionalities of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] and [`RobertaTokenizer`/`XLMRobertaTokenizer`]. See the [`~TrOCRProcessor.__call__`] and [`~TrOCRProcessor.decode`] for more information. Args: - feature_extractor ([`AutoFeatureExtractor`]): - An instance of [`AutoFeatureExtractor`]. The feature extractor is a required input. - tokenizer ([`RobertaTokenizer`]): - An instance of [`RobertaTokenizer`]. The tokenizer is a required input. + feature_extractor ([`ViTFeatureExtractor`/`DeiTFeatureExtractor`]): + An instance of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`]. The feature extractor is a required input. + tokenizer ([`RobertaTokenizer`/`XLMRobertaTokenizer`]): + An instance of [`RobertaTokenizer`/`XLMRobertaTokenizer`]. The tokenizer is a required input. """ def __init__(self, feature_extractor, tokenizer): @@ -43,9 +45,9 @@ class TrOCRProcessor: raise ValueError( f"`feature_extractor` has to be of type {FeatureExtractionMixin.__class__}, but is {type(feature_extractor)}" ) - if not (isinstance(tokenizer, RobertaTokenizer) or (isinstance(tokenizer, RobertaTokenizerFast))): + if not isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, XLMRobertaTokenizer, XLMRobertaTokenizerFast)): raise ValueError( - f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__}, but is {type(tokenizer)}" + f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__} or {XLMRobertaTokenizer.__class__} or {XLMRobertaTokenizerFast.__class__}, but is {type(tokenizer)}" ) self.feature_extractor = feature_extractor @@ -103,7 +105,7 @@ class TrOCRProcessor: [`PreTrainedTokenizer`] """ feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs) - tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)