updating CLI
This commit is contained in:
committed by
Morgan Funtowicz
parent
7c1697562a
commit
31a3a73ee3
@@ -3,6 +3,8 @@ from argparse import ArgumentParser
|
||||
|
||||
from transformers.commands.serving import ServeCommand
|
||||
from transformers.commands.user import UserCommands
|
||||
from transformers.commands.train import TrainCommand
|
||||
from transformers.commands.convert import ConvertCommand
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]')
|
||||
@@ -11,6 +13,8 @@ if __name__ == '__main__':
|
||||
# Register commands
|
||||
ServeCommand.register_subcommand(commands_parser)
|
||||
UserCommands.register_subcommand(commands_parser)
|
||||
TrainCommand.register_subcommand(commands_parser)
|
||||
ConvertCommand.register_subcommand(commands_parser)
|
||||
|
||||
# Let's go
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -25,7 +25,6 @@ from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH
|
||||
from .data import (is_sklearn_available,
|
||||
InputExample, InputFeatures, DataProcessor,
|
||||
SingleSentenceClassificationProcessor,
|
||||
convert_examples_to_features,
|
||||
glue_output_modes, glue_convert_examples_to_features,
|
||||
glue_processors, glue_tasks_num_labels,
|
||||
xnli_output_modes, xnli_processors, xnli_tasks_num_labels,
|
||||
@@ -66,6 +65,9 @@ from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CO
|
||||
from .configuration_albert import AlbertConfig, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
from .configuration_camembert import CamembertConfig, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
# Pipelines
|
||||
from .pipeline import TextClassificationPipeline
|
||||
|
||||
# Modeling
|
||||
if is_torch_available():
|
||||
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
|
||||
|
||||
@@ -3,14 +3,11 @@ from argparse import ArgumentParser, Namespace
|
||||
from logging import getLogger
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers import (AutoTokenizer, is_tf_available, is_torch_available,
|
||||
SingleSentenceClassificationProcessor,
|
||||
convert_examples_to_features)
|
||||
if is_tf_available():
|
||||
from transformers import TFAutoModelForSequenceClassification as SequenceClassifModel
|
||||
elif is_torch_available():
|
||||
from transformers import AutoModelForSequenceClassification as SequenceClassifModel
|
||||
else:
|
||||
from transformers import (is_tf_available, is_torch_available,
|
||||
TextClassificationPipeline,
|
||||
SingleSentenceClassificationProcessor as Processor)
|
||||
|
||||
if not is_tf_available() and not is_torch_available():
|
||||
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
||||
|
||||
# TF training parameters
|
||||
@@ -35,16 +32,18 @@ class TrainCommand(BaseTransformersCLICommand):
|
||||
:return:
|
||||
"""
|
||||
train_parser = parser.add_parser('train', help='CLI tool to train a model on a task.')
|
||||
|
||||
train_parser.add_argument('--train_data', type=str, required=True,
|
||||
help="path to train (and optionally evaluation) dataset as a csv with "
|
||||
"tab separated labels and sentences.")
|
||||
|
||||
train_parser.add_argument('--column_label', type=int, default=0,
|
||||
help='Column of the dataset csv file with example labels.')
|
||||
train_parser.add_argument('--column_text', type=int, default=1,
|
||||
help='Column of the dataset csv file with example texts.')
|
||||
train_parser.add_argument('--column_id', type=int, default=2,
|
||||
help='Column of the dataset csv file with example ids.')
|
||||
train_parser.add_argument('--skip_first_row', action='store_true',
|
||||
help='Skip the first row of the csv file (headers).')
|
||||
|
||||
train_parser.add_argument('--validation_data', type=str, default='',
|
||||
help='path to validation dataset.')
|
||||
@@ -74,39 +73,38 @@ class TrainCommand(BaseTransformersCLICommand):
|
||||
|
||||
self.framework = 'tf' if is_tf_available() else 'torch'
|
||||
|
||||
os.makedirs(args.output)
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
assert os.path.isdir(args.output)
|
||||
self.output = args.output
|
||||
|
||||
self.column_label = args.column_label
|
||||
self.column_text = args.column_text
|
||||
self.column_id = args.column_id
|
||||
|
||||
self.logger.info('Loading model {}'.format(args.model_name))
|
||||
self.model_name = args.model_name
|
||||
self.pipeline = AutoTokenizer.from_pretrained(args.model_name)
|
||||
self.logger.info('Loading {} pipeline for {}'.format(args.task, args.model))
|
||||
if args.task == 'text_classification':
|
||||
self.model = SequenceClassifModel.from_pretrained(args.model_name)
|
||||
self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
|
||||
elif args.task == 'token_classification':
|
||||
raise NotImplementedError
|
||||
elif args.task == 'question_answering':
|
||||
raise NotImplementedError
|
||||
|
||||
self.logger.info('Loading dataset from {}'.format(args.train_data))
|
||||
dataset = SingleSentenceClassificationProcessor.create_from_csv(args.train_data)
|
||||
num_data_samples = len(dataset)
|
||||
self.train_dataset = Processor.create_from_csv(args.train_data,
|
||||
column_label=args.column_label,
|
||||
column_text=args.column_text,
|
||||
column_id=args.column_id,
|
||||
skip_first_row=args.skip_first_row)
|
||||
self.valid_dataset = None
|
||||
if args.validation_data:
|
||||
self.logger.info('Loading validation dataset from {}'.format(args.validation_data))
|
||||
self.valid_dataset = SingleSentenceClassificationProcessor.create_from_csv(args.validation_data)
|
||||
self.num_valid_samples = len(self.valid_dataset)
|
||||
self.train_dataset = dataset
|
||||
self.num_train_samples = num_data_samples
|
||||
else:
|
||||
assert 0.0 < args.validation_split < 1.0, "--validation_split should be between 0.0 and 1.0"
|
||||
self.num_valid_samples = num_data_samples * args.validation_split
|
||||
self.num_train_samples = num_data_samples - self.num_valid_samples
|
||||
self.train_dataset = dataset[self.num_train_samples]
|
||||
self.valid_dataset = dataset[self.num_valid_samples]
|
||||
self.valid_dataset = Processor.create_from_csv(args.validation_data,
|
||||
column_label=args.column_label,
|
||||
column_text=args.column_text,
|
||||
column_id=args.column_id,
|
||||
skip_first_row=args.skip_first_row)
|
||||
|
||||
self.validation_split = args.validation_split
|
||||
self.train_batch_size = args.train_batch_size
|
||||
self.valid_batch_size = args.valid_batch_size
|
||||
self.learning_rate = args.learning_rate
|
||||
@@ -121,34 +119,13 @@ class TrainCommand(BaseTransformersCLICommand):
|
||||
raise NotImplementedError
|
||||
|
||||
def run_tf(self):
|
||||
import tensorflow as tf
|
||||
self.pipeline.fit(self.train_dataset,
|
||||
validation_data=self.valid_dataset,
|
||||
validation_split=self.validation_split,
|
||||
learning_rate=self.learning_rate,
|
||||
adam_epsilon=self.adam_epsilon,
|
||||
train_batch_size=self.train_batch_size,
|
||||
valid_batch_size=self.valid_batch_size)
|
||||
|
||||
tf.config.optimizer.set_jit(USE_XLA)
|
||||
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
|
||||
|
||||
# Prepare dataset as a tf.train_data.Dataset instance
|
||||
self.logger.info('Tokenizing and processing dataset')
|
||||
train_dataset = self.train_dataset.get_features(self.tokenizer)
|
||||
valid_dataset = self.valid_dataset.get_features(self.tokenizer)
|
||||
train_dataset = train_dataset.shuffle(128).batch(self.train_batch_size).repeat(-1)
|
||||
valid_dataset = valid_dataset.batch(self.valid_batch_size)
|
||||
|
||||
# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
|
||||
opt = tf.keras.optimizers.Adam(learning_rate=args.learning_rate, epsilon=self.adam_epsilon)
|
||||
if USE_AMP:
|
||||
# loss scaling is currently required when using mixed precision
|
||||
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
|
||||
self.model.compile(optimizer=opt, loss=loss, metrics=[metric])
|
||||
|
||||
# Train and evaluate using tf.keras.Model.fit()
|
||||
train_steps = self.num_train_samples//self.train_batch_size
|
||||
valid_steps = self.num_valid_samples//self.valid_batch_size
|
||||
|
||||
self.logger.info('Training model')
|
||||
history = self.model.fit(train_dataset, epochs=2, steps_per_epoch=train_steps,
|
||||
validation_data=valid_dataset, validation_steps=valid_steps)
|
||||
|
||||
# Save trained model
|
||||
self.model.save_pretrained(self.output)
|
||||
# Save trained pipeline
|
||||
self.pipeline.save_pretrained(self.output)
|
||||
|
||||
@@ -122,14 +122,30 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
return self.examples[idx]
|
||||
|
||||
@classmethod
|
||||
def create_from_csv(cls, file_name, **kwargs):
|
||||
def create_from_csv(cls, file_name, split_name='', column_label=0, column_text=1,
|
||||
column_id=None, skip_first_row=False, **kwargs):
|
||||
processor = cls(**kwargs)
|
||||
processor.add_examples_from_csv(file_name)
|
||||
processor.add_examples_from_csv(file_name,
|
||||
split_name=split_name,
|
||||
column_label=column_label,
|
||||
column_text=column_text,
|
||||
column_id=column_id,
|
||||
skip_first_row=skip_first_row,
|
||||
overwrite_labels=True,
|
||||
overwrite_examples=True)
|
||||
return processor
|
||||
|
||||
@classmethod
|
||||
def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
|
||||
processor = cls(**kwargs)
|
||||
processor.add_examples(texts_or_text_and_labels, labels=labels)
|
||||
return processor
|
||||
|
||||
def add_examples_from_csv(self, file_name, split_name='', column_label=0, column_text=1, column_id=None,
|
||||
overwrite_labels=False, overwrite_examples=False):
|
||||
skip_first_row=False, overwrite_labels=False, overwrite_examples=False):
|
||||
lines = self._read_tsv(file_name)
|
||||
if skip_first_row:
|
||||
lines = lines[1:]
|
||||
texts = []
|
||||
labels = []
|
||||
ids = []
|
||||
@@ -144,15 +160,21 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
|
||||
return self.add_examples(texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples)
|
||||
|
||||
def add_examples(self, texts, labels, ids=None, overwrite_labels=False, overwrite_examples=False):
|
||||
def add_examples(self, texts_or_text_and_labels, labels=None, ids=None,
|
||||
overwrite_labels=False, overwrite_examples=False):
|
||||
assert labels is None or len(texts_or_text_and_labels) == len(labels)
|
||||
assert ids is None or len(texts_or_text_and_labels) == len(ids)
|
||||
if ids is None:
|
||||
ids = [None] * len(texts)
|
||||
assert len(texts) == len(labels)
|
||||
assert len(texts) == len(ids)
|
||||
|
||||
ids = [None] * len(texts_or_text_and_labels)
|
||||
if labels is None:
|
||||
labels = [None] * len(texts_or_text_and_labels)
|
||||
examples = []
|
||||
added_labels = set()
|
||||
for (text, label, guid) in zip(texts, labels, ids):
|
||||
for (text_or_text_and_label, label, guid) in zip(texts_or_text_and_labels, labels, ids):
|
||||
if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
|
||||
text, label = text_or_text_and_label
|
||||
else:
|
||||
text = text_or_text_and_label
|
||||
added_labels.add(label)
|
||||
examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
|
||||
|
||||
@@ -170,12 +192,6 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
|
||||
return self.examples
|
||||
|
||||
@classmethod
|
||||
def create_from_examples(cls, texts, labels, **kwargs):
|
||||
processor = cls(**kwargs)
|
||||
processor.add_examples(texts, labels)
|
||||
return processor
|
||||
|
||||
def get_features(self,
|
||||
tokenizer,
|
||||
max_length=None,
|
||||
@@ -204,6 +220,8 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
a list of task-specific ``InputFeatures`` which can be fed to the model.
|
||||
|
||||
"""
|
||||
if max_length is None:
|
||||
max_length = tokenizer.max_len
|
||||
|
||||
label_map = {label: i for i, label in enumerate(self.labels)}
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ import logging
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras.saving import hdf5_format
|
||||
import h5py
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
|
||||
@@ -206,6 +208,9 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
|
||||
output_loading_info: (`optional`) boolean:
|
||||
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
||||
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
|
||||
|
||||
@@ -229,6 +234,7 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
resume_download = kwargs.pop('resume_download', False)
|
||||
proxies = kwargs.pop('proxies', None)
|
||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||
|
||||
# Load config
|
||||
if config is None:
|
||||
@@ -304,6 +310,31 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
|
||||
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
||||
|
||||
# Check if the models are the same to output loading informations
|
||||
with h5py.File(resolved_archive_file, 'r') as f:
|
||||
if 'layer_names' not in f.attrs and 'model_weights' in f:
|
||||
f = f['model_weights']
|
||||
hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, 'layer_names'))
|
||||
model_layer_names = set(layer.name for layer in model.layers)
|
||||
missing_keys = list(model_layer_names - hdf5_layer_names)
|
||||
unexpected_keys = list(hdf5_layer_names - model_layer_names)
|
||||
error_msgs = []
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
logger.info("Layers of {} not initialized from pretrained model: {}".format(
|
||||
model.__class__.__name__, missing_keys))
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.info("Layers from pretrained model not used in {}: {}".format(
|
||||
model.__class__.__name__, unexpected_keys))
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading weights for {}:\n\t{}'.format(
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
if output_loading_info:
|
||||
loading_info = {"missing_keys": missing_keys,
|
||||
"unexpected_keys": unexpected_keys,
|
||||
"error_msgs": error_msgs}
|
||||
return model, loading_info
|
||||
|
||||
return model
|
||||
|
||||
class TFConv1D(tf.keras.layers.Layer):
|
||||
|
||||
@@ -17,18 +17,22 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
import os
|
||||
import logging
|
||||
import six
|
||||
|
||||
from .modeling_auto import (AutoModel, AutoModelForQuestionAnswering,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelWithLMHead)
|
||||
from .tokenization_auto import AutoTokenizer
|
||||
from .file_utils import add_start_docstrings, is_tf_available, is_torch_available
|
||||
from .data.processors import SingleSentenceClassificationProcessor
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForQuestionAnswering,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFAutoModelWithLMHead)
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from .modeling_auto import (AutoModel, AutoModelForQuestionAnswering,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelWithLMHead)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -61,12 +65,6 @@ class TextClassificationPipeline(object):
|
||||
def __init__(self, tokenizer, model, is_compiled=False, is_trained=False):
|
||||
self.tokenizer = tokenizer
|
||||
self.model = model
|
||||
if is_tf_available():
|
||||
self.framework = 'tf'
|
||||
elif is_torch_available():
|
||||
self.framework = 'pt'
|
||||
else:
|
||||
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
||||
self.is_compiled = is_compiled
|
||||
self.is_trained = is_trained
|
||||
|
||||
@@ -94,8 +92,11 @@ class TextClassificationPipeline(object):
|
||||
# used for both the tokenizer and the model
|
||||
model_kwargs[key] = kwargs[key]
|
||||
|
||||
model_kwargs['output_loading_info'] = True
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **tokenizer_kwargs)
|
||||
model_kwargs['output_loading_info'] = True
|
||||
if is_tf_available():
|
||||
model, loading_info = TFAutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path, **model_kwargs)
|
||||
else:
|
||||
model, loading_info = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path, **model_kwargs)
|
||||
|
||||
return cls(tokenizer, model, is_trained=bool(not loading_info['missing_keys']))
|
||||
@@ -109,36 +110,42 @@ class TextClassificationPipeline(object):
|
||||
self.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
|
||||
def prepare_data(self, train_samples_text, train_samples_labels,
|
||||
valid_samples_text=None, valid_samples_labels=None,
|
||||
def prepare_data(self, x, y=None,
|
||||
validation_data=None,
|
||||
validation_split=0.1, **kwargs):
|
||||
dataset = SingleSentenceClassificationProcessor.create_from_examples(train_samples_text,
|
||||
train_samples_labels)
|
||||
dataset = x
|
||||
if not isinstance(x, SingleSentenceClassificationProcessor):
|
||||
dataset = SingleSentenceClassificationProcessor.create_from_examples(x, y)
|
||||
num_data_samples = len(dataset)
|
||||
if valid_samples_text is not None and valid_samples_labels is not None:
|
||||
valid_dataset = SingleSentenceClassificationProcessor.create_from_examples(valid_samples_text,
|
||||
valid_samples_labels)
|
||||
|
||||
if validation_data is not None:
|
||||
valid_dataset = validation_data
|
||||
if not isinstance(validation_data, SingleSentenceClassificationProcessor):
|
||||
valid_dataset = SingleSentenceClassificationProcessor.create_from_examples(validation_data)
|
||||
|
||||
num_valid_samples = len(valid_dataset)
|
||||
train_dataset = dataset
|
||||
num_train_samples = num_data_samples
|
||||
else:
|
||||
assert 0.0 <= validation_split <= 1.0, "validation_split should be between 0.0 and 1.0"
|
||||
num_valid_samples = int(num_data_samples * validation_split)
|
||||
num_valid_samples = max(int(num_data_samples * validation_split), 1)
|
||||
num_train_samples = num_data_samples - num_valid_samples
|
||||
train_dataset = dataset[num_train_samples]
|
||||
valid_dataset = dataset[num_valid_samples]
|
||||
train_dataset = dataset[num_valid_samples:]
|
||||
valid_dataset = dataset[:num_valid_samples]
|
||||
|
||||
logger.info('Tokenizing and processing dataset')
|
||||
train_dataset = train_dataset.get_features(self.tokenizer, return_tensors=self.framework)
|
||||
valid_dataset = valid_dataset.get_features(self.tokenizer, return_tensors=self.framework)
|
||||
return train_dataset, valid_dataset, num_train_samples, num_valid_samples
|
||||
train_dataset = train_dataset.get_features(self.tokenizer,
|
||||
return_tensors='tf' if is_tf_available() else 'pt')
|
||||
valid_dataset = valid_dataset.get_features(self.tokenizer,
|
||||
return_tensors='tf' if is_tf_available() else 'pt')
|
||||
return train_dataset, valid_dataset
|
||||
|
||||
|
||||
def compile(self, learning_rate=3e-5, epsilon=1e-8, **kwargs):
|
||||
if self.framework == 'tf':
|
||||
def compile(self, learning_rate=3e-5, adam_epsilon=1e-8, **kwargs):
|
||||
if is_tf_available():
|
||||
logger.info('Preparing model')
|
||||
# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
|
||||
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=epsilon)
|
||||
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=adam_epsilon)
|
||||
if USE_AMP:
|
||||
# loss scaling is currently required when using mixed precision
|
||||
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
|
||||
@@ -150,39 +157,37 @@ class TextClassificationPipeline(object):
|
||||
self.is_compiled = True
|
||||
|
||||
|
||||
def fit(self, train_samples_text=None, train_samples_labels=None,
|
||||
valid_samples_text=None, valid_samples_labels=None,
|
||||
train_batch_size=None, valid_batch_size=None,
|
||||
def fit(self, X=None, y=None,
|
||||
validation_data=None,
|
||||
validation_split=0.1,
|
||||
train_batch_size=None,
|
||||
valid_batch_size=None,
|
||||
**kwargs):
|
||||
|
||||
# Generic compatibility with sklearn and Keras
|
||||
if 'y' in kwargs and train_samples_labels is None:
|
||||
train_samples_labels = kwargs.pop('y')
|
||||
if 'X' in kwargs and train_samples_text is None:
|
||||
train_samples_text = kwargs.pop('X')
|
||||
|
||||
if not self.is_compiled:
|
||||
self.compile(**kwargs)
|
||||
|
||||
datasets = self.prepare_data(train_samples_text, train_samples_labels,
|
||||
valid_samples_text, valid_samples_labels,
|
||||
validation_split)
|
||||
train_dataset, valid_dataset, num_train_samples, num_valid_samples = datasets
|
||||
train_dataset, valid_dataset = self.prepare_data(X, y=y,
|
||||
validation_data=validation_data,
|
||||
validation_split=validation_split)
|
||||
num_train_samples = len(train_dataset)
|
||||
num_valid_samples = len(valid_dataset)
|
||||
|
||||
train_steps = num_train_samples//train_batch_size
|
||||
valid_steps = num_valid_samples//valid_batch_size
|
||||
|
||||
if self.framework == 'tf':
|
||||
if is_tf_available():
|
||||
# Prepare dataset as a tf.train_data.Dataset instance
|
||||
train_dataset = train_dataset.shuffle(128).batch(train_batch_size).repeat(-1)
|
||||
valid_dataset = valid_dataset.batch(valid_batch_size)
|
||||
|
||||
logger.info('Training TF 2.0 model')
|
||||
history = self.model.fit(train_dataset, epochs=2, steps_per_epoch=train_steps,
|
||||
validation_data=valid_dataset, validation_steps=valid_steps, **kwargs)
|
||||
validation_data=valid_dataset, validation_steps=valid_steps,
|
||||
**kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.is_trained = True
|
||||
|
||||
|
||||
@@ -210,9 +215,11 @@ class TextClassificationPipeline(object):
|
||||
if not self.is_trained:
|
||||
logger.error("Some weights of the model are not trained. Please fine-tune the model on a classification task before using it.")
|
||||
|
||||
inputs = self.tokenizer.batch_encode_plus(texts, add_special_tokens=True, return_tensors=self.framework)
|
||||
inputs = self.tokenizer.batch_encode_plus(texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors='tf' if is_tf_available() else 'pt')
|
||||
|
||||
if self.framework == 'tf':
|
||||
if is_tf_available():
|
||||
# TODO trace model
|
||||
predictions = self.model(**inputs)[0]
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user