Merge pull request #1628 from huggingface/tfglue
run_tf_glue works with all tasks
This commit is contained in:
@@ -1,29 +1,47 @@
|
|||||||
import os
|
import os
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow_datasets
|
import tensorflow_datasets
|
||||||
from transformers import BertTokenizer, TFBertForSequenceClassification, glue_convert_examples_to_features, BertForSequenceClassification
|
from transformers import BertTokenizer, TFBertForSequenceClassification, BertConfig, glue_convert_examples_to_features, BertForSequenceClassification, glue_processors
|
||||||
|
|
||||||
# script parameters
|
# script parameters
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
EVAL_BATCH_SIZE = BATCH_SIZE * 2
|
EVAL_BATCH_SIZE = BATCH_SIZE * 2
|
||||||
USE_XLA = False
|
USE_XLA = False
|
||||||
USE_AMP = False
|
USE_AMP = False
|
||||||
|
EPOCHS = 3
|
||||||
|
|
||||||
|
TASK = "mrpc"

# Two GLUE task names are spelled without the hyphen on the
# tensorflow_datasets side; every other task name is passed through as-is.
TFDS_TASK = {"sst-2": "sst2", "sts-b": "stsb"}.get(TASK, TASK)
|
||||||
|
|
||||||
|
num_labels = len(glue_processors[TASK]().get_labels())
|
||||||
|
print(num_labels)
|
||||||
|
|
||||||
tf.config.optimizer.set_jit(USE_XLA)
|
tf.config.optimizer.set_jit(USE_XLA)
|
||||||
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
|
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
|
||||||
|
|
||||||
# Load tokenizer and model from pretrained model/vocabulary
|
# Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
|
||||||
|
config = BertConfig.from_pretrained("bert-base-cased", num_labels=num_labels)
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
||||||
model = TFBertForSequenceClassification.from_pretrained('bert-base-cased')
|
model = TFBertForSequenceClassification.from_pretrained('bert-base-cased', config=config)
|
||||||
|
|
||||||
# Load dataset via TensorFlow Datasets
|
# Load dataset via TensorFlow Datasets
|
||||||
data, info = tensorflow_datasets.load('glue/mrpc', with_info=True)
|
data, info = tensorflow_datasets.load(f'glue/{TFDS_TASK}', with_info=True)
|
||||||
train_examples = info.splits['train'].num_examples
|
train_examples = info.splits['train'].num_examples
|
||||||
|
|
||||||
|
# MNLI has no plain 'validation' split: use 'validation_matched' or 'validation_mismatched' instead
|
||||||
valid_examples = info.splits['validation'].num_examples
|
valid_examples = info.splits['validation'].num_examples
|
||||||
|
|
||||||
# Prepare dataset for GLUE as a tf.data.Dataset instance
|
# Prepare dataset for GLUE as a tf.data.Dataset instance
|
||||||
train_dataset = glue_convert_examples_to_features(data['train'], tokenizer, 128, 'mrpc')
|
train_dataset = glue_convert_examples_to_features(data['train'], tokenizer, 128, TASK)
|
||||||
valid_dataset = glue_convert_examples_to_features(data['validation'], tokenizer, 128, 'mrpc')
|
|
||||||
|
# MNLI has no plain 'validation' split: use 'validation_matched' or 'validation_mismatched' instead
|
||||||
|
valid_dataset = glue_convert_examples_to_features(data['validation'], tokenizer, 128, TASK)
|
||||||
train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
|
train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
|
||||||
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
|
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
|
||||||
|
|
||||||
@@ -32,7 +50,13 @@ opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
|
|||||||
if USE_AMP:
|
if USE_AMP:
|
||||||
# loss scaling is currently required when using mixed precision
|
# loss scaling is currently required when using mixed precision
|
||||||
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
|
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
|
||||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
|
||||||
|
|
||||||
|
if num_labels == 1:
|
||||||
|
loss = tf.keras.losses.MeanSquaredError()
|
||||||
|
else:
|
||||||
|
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||||
|
|
||||||
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
|
# Accuracy is only meaningful for classification. With a single regression
# output (num_labels == 1) report the mean squared error instead.
if num_labels == 1:
    metric = tf.keras.metrics.MeanSquaredError('mse')
else:
    metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
|
||||||
model.compile(optimizer=opt, loss=loss, metrics=[metric])
|
model.compile(optimizer=opt, loss=loss, metrics=[metric])
|
||||||
|
|
||||||
@@ -40,7 +64,7 @@ model.compile(optimizer=opt, loss=loss, metrics=[metric])
|
|||||||
train_steps = train_examples//BATCH_SIZE
|
train_steps = train_examples//BATCH_SIZE
|
||||||
valid_steps = valid_examples//EVAL_BATCH_SIZE
|
valid_steps = valid_examples//EVAL_BATCH_SIZE
|
||||||
|
|
||||||
history = model.fit(train_dataset, epochs=2, steps_per_epoch=train_steps,
|
history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_steps,
|
||||||
validation_data=valid_dataset, validation_steps=valid_steps)
|
validation_data=valid_dataset, validation_steps=valid_steps)
|
||||||
|
|
||||||
# Save TF2 model
|
# Save TF2 model
|
||||||
@@ -57,6 +81,9 @@ sentence_2 = 'His findings were not compatible with this research.'
|
|||||||
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')
|
inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')
|
||||||
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')
|
inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')
|
||||||
|
|
||||||
|
del inputs_1["special_tokens_mask"]
|
||||||
|
del inputs_2["special_tokens_mask"]
|
||||||
|
|
||||||
pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
|
pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
|
||||||
pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
|
pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
|
||||||
print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0')
|
print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0')
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
|
|||||||
logger.info("Writing example %d" % (ex_index))
|
logger.info("Writing example %d" % (ex_index))
|
||||||
if is_tf_dataset:
|
if is_tf_dataset:
|
||||||
example = processor.get_example_from_tensor_dict(example)
|
example = processor.get_example_from_tensor_dict(example)
|
||||||
|
example = processor.tfds_map(example)
|
||||||
|
|
||||||
inputs = tokenizer.encode_plus(
|
inputs = tokenizer.encode_plus(
|
||||||
example.text_a,
|
example.text_a,
|
||||||
|
|||||||
@@ -107,6 +107,13 @@ class DataProcessor(object):
|
|||||||
"""Gets the list of labels for this data set."""
|
"""Gets the list of labels for this data set."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def tfds_map(self, example):
    """Convert a tensorflow_datasets example to the GLUE example format.

    Some tensorflow_datasets datasets are not formatted the same way the
    GLUE datasets are. When this processor declares more than one label,
    ``example.label`` is interpreted as an integer index into
    ``get_labels()`` and is replaced, in place, by the label stored at that
    index. With a single label (or none) the example is left untouched.

    Returns the same ``example`` object that was passed in.
    """
    labels = self.get_labels()
    if len(labels) <= 1:
        return example
    example.label = labels[int(example.label)]
    return example
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _read_tsv(cls, input_file, quotechar=None):
|
def _read_tsv(cls, input_file, quotechar=None):
|
||||||
"""Reads a tab separated value file."""
|
"""Reads a tab separated value file."""
|
||||||
|
|||||||
Reference in New Issue
Block a user