Fixed label datatype for STS-B (#6492)

* fixed label datatype for sts-b

* naming update

* make style

* make style
This commit is contained in:
Ali Modarressi
2020-08-18 16:39:39 +04:30
committed by GitHub
parent 12d7624199
commit 5a81195ea9

View File

@@ -79,6 +79,7 @@ if is_tf_available():
processor = glue_processors[task]() processor = glue_processors[task]()
examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples] examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task) features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
label_type = tf.float32 if task == "sts-b" else tf.int64
def gen(): def gen():
for ex in features: for ex in features:
@@ -90,7 +91,7 @@ if is_tf_available():
return tf.data.Dataset.from_generator( return tf.data.Dataset.from_generator(
gen, gen,
({k: tf.int32 for k in input_names}, tf.int64), ({k: tf.int32 for k in input_names}, label_type),
({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])), ({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
) )