diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 924faef458..9f28e01169 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -25,9 +25,12 @@ from tests.utils import require_tf, require_torch def merge_model_tokenizer_mappings( - model_mapping: "Dict[PretrainedConfig, Union[PreTrainedModel, TFPreTrainedModel]]", # noqa: F821 - tokenizer_mapping: "Dict[PretrainedConfig, Tuple[PreTrainedTokenizer, PreTrainedTokenizerFast]]", # noqa: F821 -) -> "Dict[Union[PreTrainedTokenizer, PreTrainedTokenizerFast], Tuple[PretrainedConfig, Union[PreTrainedModel, TFPreTrainedModel]]]": # noqa: F821 + model_mapping: Dict["PretrainedConfig", Union["PreTrainedModel", "TFPreTrainedModel"]], + tokenizer_mapping: Dict["PretrainedConfig", Tuple["PreTrainedTokenizer", "PreTrainedTokenizerFast"]], +) -> Dict[ + Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"], + Tuple["PretrainedConfig", Union["PreTrainedModel", "TFPreTrainedModel"]], +]: configurations = list(model_mapping.keys()) model_tokenizer_mapping = OrderedDict([])