From 83272a3853fe8bfad055ab043d432e7eebdbaae3 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 25 Mar 2020 11:10:20 -0400 Subject: [PATCH] Experiment w/ dataclasses (including Py36) (#3423) * [ci] Also run test_examples in py37 (will revert at the end of the experiment) * InputExample: use immutable dataclass * [deps] Install dataclasses for Py<3.7 * [skip ci] Revert "[ci] Also run test_examples in py37" This reverts commit d29afd9959786b77759b0b8fa4e6b4335b952015. --- setup.py | 2 ++ src/transformers/data/processors/utils.py | 25 +++++++++-------------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/setup.py b/setup.py index 008fe59931..c1d1d3f8cd 100644 --- a/setup.py +++ b/setup.py @@ -97,6 +97,8 @@ setup( install_requires=[ "numpy", "tokenizers == 0.5.2", + # dataclasses for Python versions that don't have it + "dataclasses;python_version<'3.7'", # accessing files from S3 directly "boto3", # filesystem locks e.g. to prevent parallel downloads diff --git a/src/transformers/data/processors/utils.py b/src/transformers/data/processors/utils.py index 0c31c6ce96..82fec40223 100644 --- a/src/transformers/data/processors/utils.py +++ b/src/transformers/data/processors/utils.py @@ -16,8 +16,11 @@ import copy import csv +import dataclasses import json import logging +from dataclasses import dataclass +from typing import Optional from ...file_utils import is_tf_available, is_torch_available @@ -25,7 +28,8 @@ from ...file_utils import is_tf_available, is_torch_available logger = logging.getLogger(__name__) -class InputExample(object): +@dataclass(frozen=True) +class InputExample: """ A single training/test example for simple sequence classification. @@ -39,23 +43,14 @@ class InputExample(object): specified for train and dev examples, but not for test examples. """ - def __init__(self, guid, text_a, text_b=None, label=None): - self.guid = guid - self.text_a = text_a - self.text_b = text_b - self.label = label - - def __repr__(self): - return str(self.to_json_string()) - - def to_dict(self): - """Serializes this instance to a Python dictionary.""" - output = copy.deepcopy(self.__dict__) - return output + guid: str + text_a: str + text_b: Optional[str] = None + label: Optional[str] = None def to_json_string(self): """Serializes this instance to a JSON string.""" - return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + return json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n" class InputFeatures(object):