Files
HuggingFace_transformer/src/transformers/hf_argparser.py
Julien Plu f9414f7553 Tensorflow improvements (#4530)
* Better None gradients handling

* Apply Style

* Apply Style

* Create a loss class per task to compute its respective loss

* Add loss classes to the ALBERT TF models

* Add loss classes to the BERT TF models

* Add question answering and multiple choice to TF Camembert

* Remove prints

* Add multiple choice model to TF DistilBERT + loss computation

* Add question answering model to TF Electra + loss computation

* Add token classification, question answering and multiple choice models to TF Flaubert

* Add multiple choice model to TF Roberta + loss computation

* Add multiple choice model to TF XLM + loss computation

* Add multiple choice and question answering models to TF XLM-Roberta

* Add multiple choice model to TF XLNet + loss computation

* Remove unused parameters

* Add task loss classes

* Reorder TF imports + add new model classes

* Add new model classes

* Bugfix in TF T5 model

* Bugfix for TF T5 tests

* Bugfix in TF T5 model

* Fix TF T5 model tests

* Fix T5 tests + some renaming

* Fix inheritance issue in the AutoX tests

* Add tests for TF Flaubert and TF XLM Roberta

* Add tests for TF Flaubert and TF XLM Roberta

* Remove unused piece of code in the TF trainer

* bugfix and remove unused code

* Bugfix for TF 2.2

* Apply Style

* Divide TFSequenceClassificationAndMultipleChoiceLoss into their two respective names

* Apply style

* Mirror the PT Trainer in the TF one: fp16, optimizers and tb_writer as class parameter and better dataset handling

* Fix TF optimizations tests and apply style

* Remove useless parameter

* Bugfix and apply style

* Fix TF Trainer prediction

* Now the TF models return the loss like their PyTorch counterparts

* Apply Style

* Ignore some tests output

* Take into account the SQuAD cls_index, p_mask and is_impossible parameters for the QuestionAnswering task models.

* Fix names for SQuAD data

* Apply Style

* Fix conflicts with 2.11 release

* Fix conflicts with 2.11

* Fix wrong name

* Add better documentation on the new create_optimizer function

* Fix isort

* logging_dir: use same default as PyTorch

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
2020-06-04 19:45:53 -04:00

161 lines
7.1 KiB
Python

import dataclasses
import json
import sys
from argparse import ArgumentParser
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, List, NewType, Tuple, Union
# Nominal aliases used only in annotations: an *instance* of a dataclass and a
# dataclass *type*, respectively. Both are NewTypes over `Any`.
DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)


class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
    to generate arguments.

    The class is designed to play well with the native argparse. In particular,
    you can add more (non-dataclass backed) arguments to the parser after initialization
    and you'll get the output back after parsing as an additional namespace.
    """

    # Dataclass types whose fields were registered as arguments, in the order
    # they were given to the initializer.
    dataclass_types: Iterable[DataClassType]

    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances
                with the parsed args.
            kwargs:
                (Optional) Passed to `argparse.ArgumentParser()` in the regular way.
        """
        super().__init__(**kwargs)
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [dataclass_types]
        # Materialize the iterable: it is walked here and again on every parse,
        # so a one-shot iterable (e.g. a generator) must not be stored as-is.
        self.dataclass_types = list(dataclass_types)
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    @staticmethod
    def _unwrap_optional(field_type):
        """
        Return `X` when `field_type` is `Optional[X]` and `X` is one of
        `int`, `float`, `str` or a `List` of one of those; otherwise return
        `field_type` unchanged.

        The check inspects `__origin__`/`__args__` instead of the textual
        representation of the annotation, because that representation differs
        between Python versions.
        """
        if getattr(field_type, "__origin__", None) is not Union:
            return field_type
        members = field_type.__args__
        non_none = [member for member in members if member is not type(None)]
        if len(members) != 2 or len(non_none) != 1:
            return field_type
        inner = non_none[0]
        primitives = (int, float, str)
        supported = list(primitives) + [List[prim] for prim in primitives]
        return inner if inner in supported else field_type

    def _add_dataclass_arguments(self, dtype: DataClassType):
        """
        Register one `--<field name>` argument per field of the dataclass `dtype`.

        The field's `metadata` mapping is forwarded verbatim as keyword
        arguments to `add_argument` (so it can carry e.g. `help`), and is
        completed with `type`/`default`/`required`/`choices`/`nargs`/`action`
        entries derived from the field's annotation and default.

        Raises:
            ImportError: if a field annotation is a string, i.e. the dataclass
                was defined with postponed evaluation of annotations.
        """
        for field in dataclasses.fields(dtype):
            field_name = f"--{field.name}"
            # field.metadata is not used at all by Data Classes,
            # it is provided as a third-party extension mechanism.
            kwargs = field.metadata.copy()
            if isinstance(field.type, str):
                raise ImportError(
                    "This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563), "
                    "which can be opted in from Python 3.7 with `from __future__ import annotations`. "
                    "We will add compatibility when Python 3.9 is released."
                )
            # Work on a local copy of the annotation: the Field objects belong
            # to the dataclass and must not be modified by the parser.
            field_type = self._unwrap_optional(field.type)
            has_default = field.default is not dataclasses.MISSING
            has_factory = field.default_factory is not dataclasses.MISSING

            if isinstance(field_type, type) and issubclass(field_type, Enum):
                # argparse converts the raw string with `field_type(value)` and
                # then checks the resulting member against `choices`.
                kwargs["choices"] = list(field_type)
                kwargs["type"] = field_type
                if has_default:
                    kwargs["default"] = field.default
                elif has_factory:
                    kwargs["default"] = field.default_factory()
                else:
                    kwargs["required"] = True
            elif field_type is bool:
                # Booleans become flags. A field defaulting to True is exposed
                # as `--no-<name>` which stores False into `<name>`.
                kwargs["action"] = "store_false" if field.default is True else "store_true"
                if field.default is True:
                    field_name = f"--no-{field.name}"
                    kwargs["dest"] = field.name
            elif hasattr(field_type, "__origin__") and issubclass(field_type.__origin__, List):
                kwargs["nargs"] = "+"
                kwargs["type"] = field_type.__args__[0]
                assert all(
                    x == kwargs["type"] for x in field_type.__args__
                ), "{} cannot be a List of mixed types".format(field.name)
                if has_factory:
                    kwargs["default"] = field.default_factory()
                elif has_default:
                    kwargs["default"] = field.default
                else:
                    # Without this the dataclass would silently receive None.
                    kwargs["required"] = True
            else:
                kwargs["type"] = field_type
                if has_default:
                    kwargs["default"] = field.default
                elif has_factory:
                    kwargs["default"] = field.default_factory()
                else:
                    kwargs["required"] = True
            self.add_argument(field_name, **kwargs)

    def parse_args_into_dataclasses(
        self, args=None, return_remaining_strings=False, look_for_args_file=True
    ) -> Tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass types.

        This relies on argparse's `ArgumentParser.parse_known_args`.
        See the doc at:
        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

        Args:
            args:
                List of strings to parse. The default is taken from sys.argv.
                (same as argparse.ArgumentParser)
            return_remaining_strings:
                If true, also return a list of remaining argument strings.
            look_for_args_file:
                If true, will look for a ".args" file with the same base name
                as the entry point script for this process, and will add its
                potential content to the command line args.

        Returns:
            Tuple consisting of:
                - the dataclass instances in the same order as they
                  were passed to the initializer.
                - if applicable, an additional namespace for more
                  (non-dataclass backed) arguments added to the parser
                  after initialization.
                - The potential list of remaining argument strings.
                  (same as argparse.ArgumentParser.parse_known_args)

        Raises:
            ValueError: if some argument strings are left unparsed and
                `return_remaining_strings` is false.
        """
        if look_for_args_file and sys.argv:
            args_file = Path(sys.argv[0]).with_suffix(".args")
            if args_file.exists():
                fargs = args_file.read_text().split()
                # The file's arguments are placed *before* the explicit ones:
                # when an option is repeated, argparse's storing actions keep
                # the last occurrence, so explicit arguments win over the file.
                explicit_args = list(args) if args is not None else sys.argv[1:]
                args = fargs + explicit_args
        namespace, remaining_args = self.parse_known_args(args=args)
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype)}
            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
            # Remove the consumed entries so that only non-dataclass arguments
            # remain in the namespace afterwards.
            for k in keys:
                delattr(namespace, k)
            outputs.append(dtype(**inputs))
        if vars(namespace):
            # additional namespace.
            outputs.append(namespace)
        if return_remaining_strings:
            return (*outputs, remaining_args)
        if remaining_args:
            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
        return (*outputs,)

    def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all,
        instead loading a json file and populating the dataclass types.

        Args:
            json_file:
                Path of a file containing a JSON object. Keys that do not match
                a field name of a given dataclass are ignored for that dataclass.

        Returns:
            Tuple of dataclass instances, in the same order as the dataclass
            types were passed to the initializer.
        """
        data = json.loads(Path(json_file).read_text())
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype)}
            inputs = {k: v for k, v in data.items() if k in keys}
            outputs.append(dtype(**inputs))
        return (*outputs,)