From 5a81195ea9f4d06b3638b2a76b3e195cd240835d Mon Sep 17 00:00:00 2001 From: Ali Modarressi Date: Tue, 18 Aug 2020 16:39:39 +0430 Subject: [PATCH] Fixed label datatype for STS-B (#6492) * fixed label datatype for sts-b * naming update * make style * make style --- src/transformers/data/processors/glue.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/data/processors/glue.py b/src/transformers/data/processors/glue.py index bc28cdc3df..e8e0cd21d4 100644 --- a/src/transformers/data/processors/glue.py +++ b/src/transformers/data/processors/glue.py @@ -79,6 +79,7 @@ if is_tf_available(): processor = glue_processors[task]() examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples] features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task) + label_type = tf.float32 if task == "sts-b" else tf.int64 def gen(): for ex in features: @@ -90,7 +91,7 @@ if is_tf_available(): return tf.data.Dataset.from_generator( gen, - ({k: tf.int32 for k in input_names}, tf.int64), + ({k: tf.int32 for k in input_names}, label_type), ({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])), )