Add MultipleChoice to TFTrainer [WIP] (#4270)
* catch gpu len 1 set to gpu0 * Add mpc to trainer * Add MPC for TF * fix TF automodel for MPC and add Albert * Apply style * Fix import * Note to self: double check * Make shape None, None for datasetgenerator output shapes * Add from_pt bool which doesnt seem to work * Original checkpoint dir * Fix docstrings for automodel * Update readme and apply style * Colab should probably not be from users * Colabs should probably not be from users * Add colab * Update README.md * Update README.md * Cleanup __intit__ * Cleanup flake8 trailing comma * Update src/transformers/training_args_tf.py * Update src/transformers/modeling_tf_auto.py Co-authored-by: Viktor Alm <viktoralm@pop-os.localdomain> Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -19,7 +19,7 @@ This is still a work-in-progress – in particular documentation is still sparse
|
|||||||
| [`language-modeling`](./language-modeling) | Raw text | ✅ | - | - | - | - |
|
| [`language-modeling`](./language-modeling) | Raw text | ✅ | - | - | - | - |
|
||||||
| [`text-classification`](./text-classification) | GLUE, XNLI | ✅ | ✅ | ✅ | [](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/trainer/01_text_classification.ipynb) | [](https://portal.azure.com/#create/Microsoft.Template/uri/https%3A%2F%2Fraw.githubusercontent.com%2FAzure%2Fazure-quickstart-templates%2Fmaster%2F101-storage-account-create%2Fazuredeploy.json) |
|
| [`text-classification`](./text-classification) | GLUE, XNLI | ✅ | ✅ | ✅ | [](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/trainer/01_text_classification.ipynb) | [](https://portal.azure.com/#create/Microsoft.Template/uri/https%3A%2F%2Fraw.githubusercontent.com%2FAzure%2Fazure-quickstart-templates%2Fmaster%2F101-storage-account-create%2Fazuredeploy.json) |
|
||||||
| [`token-classification`](./token-classification) | CoNLL NER | ✅ | ✅ | ✅ | - | - |
|
| [`token-classification`](./token-classification) | CoNLL NER | ✅ | ✅ | ✅ | - | - |
|
||||||
| [`multiple-choice`](./multiple-choice) | SWAG, RACE, ARC | ✅ | - | - | - | - |
|
| [`multiple-choice`](./multiple-choice) | SWAG, RACE, ARC | ✅ | ✅ | - | [](https://colab.research.google.com/github/ViktorAlm/notebooks/blob/master/MPC_GPU_Demo_for_TF_and_PT.ipynb) | - |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -29,3 +29,28 @@ Training with the defined hyper-parameters yields the following results:
|
|||||||
eval_acc = 0.8338998300509847
|
eval_acc = 0.8338998300509847
|
||||||
eval_loss = 0.44457291918821606
|
eval_loss = 0.44457291918821606
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Tensorflow
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export SWAG_DIR=/path/to/swag_data_dir
|
||||||
|
python ./examples/multiple-choice/run_tf_multiple_choice.py \
|
||||||
|
--task_name swag \
|
||||||
|
--model_name_or_path bert-base-cased \
|
||||||
|
--do_train \
|
||||||
|
--do_eval \
|
||||||
|
--data_dir $SWAG_DIR \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3 \
|
||||||
|
--max_seq_length 80 \
|
||||||
|
--output_dir models_bert/swag_base \
|
||||||
|
--per_gpu_eval_batch_size=16 \
|
||||||
|
--per_gpu_train_batch_size=16 \
|
||||||
|
--logging-dir logs \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--overwrite_output
|
||||||
|
```
|
||||||
|
|
||||||
|
# Run it in colab
|
||||||
|
[](https://colab.research.google.com/github/ViktorAlm/notebooks/blob/master/MPC_GPU_Demo_for_TF_and_PT.ipynb)
|
||||||
|
|||||||
211
examples/multiple-choice/run_tf_multiple_choice.py
Normal file
211
examples/multiple-choice/run_tf_multiple_choice.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||||
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Finetuning the library models for multiple choice (Bert, Roberta, XLNet)."""
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoTokenizer,
|
||||||
|
EvalPrediction,
|
||||||
|
HfArgumentParser,
|
||||||
|
TFAutoModelForMultipleChoice,
|
||||||
|
TFTrainer,
|
||||||
|
TFTrainingArguments,
|
||||||
|
set_seed,
|
||||||
|
)
|
||||||
|
from utils_multiple_choice import Split, TFMultipleChoiceDataset, processors
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def simple_accuracy(preds, labels):
|
||||||
|
return (preds == labels).mean()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArguments:
|
||||||
|
"""
|
||||||
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_name_or_path: str = field(
|
||||||
|
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||||
|
)
|
||||||
|
config_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||||
|
)
|
||||||
|
tokenizer_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||||
|
)
|
||||||
|
cache_dir: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataTrainingArguments:
|
||||||
|
"""
|
||||||
|
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(processors.keys())})
|
||||||
|
data_dir: str = field(metadata={"help": "Should contain the data files for the task."})
|
||||||
|
max_seq_length: int = field(
|
||||||
|
default=128,
|
||||||
|
metadata={
|
||||||
|
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
overwrite_cache: bool = field(
|
||||||
|
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# See all possible arguments in src/transformers/training_args.py
|
||||||
|
# or by passing the --help flag to this script.
|
||||||
|
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||||
|
|
||||||
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
|
||||||
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||||
|
|
||||||
|
if (
|
||||||
|
os.path.exists(training_args.output_dir)
|
||||||
|
and os.listdir(training_args.output_dir)
|
||||||
|
and training_args.do_train
|
||||||
|
and not training_args.overwrite_output_dir
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
level=logging.INFO,
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"device: %s, n_gpu: %s, 16-bits training: %s", training_args.device, training_args.n_gpu, training_args.fp16,
|
||||||
|
)
|
||||||
|
logger.info("Training/evaluation parameters %s", training_args)
|
||||||
|
|
||||||
|
# Set seed
|
||||||
|
set_seed(training_args.seed)
|
||||||
|
|
||||||
|
try:
|
||||||
|
processor = processors[data_args.task_name]()
|
||||||
|
label_list = processor.get_labels()
|
||||||
|
num_labels = len(label_list)
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError("Task not found: %s" % (data_args.task_name))
|
||||||
|
|
||||||
|
# Load pretrained model and tokenizer
|
||||||
|
#
|
||||||
|
# Distributed training:
|
||||||
|
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||||
|
# download model & vocab.
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||||
|
num_labels=num_labels,
|
||||||
|
finetuning_task=data_args.task_name,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
|
with training_args.strategy.scope():
|
||||||
|
model = TFAutoModelForMultipleChoice.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
from_pt=bool(".bin" in model_args.model_name_or_path),
|
||||||
|
config=config,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
|
# Get datasets
|
||||||
|
train_dataset = (
|
||||||
|
TFMultipleChoiceDataset(
|
||||||
|
data_dir=data_args.data_dir,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
task=data_args.task_name,
|
||||||
|
max_seq_length=data_args.max_seq_length,
|
||||||
|
overwrite_cache=data_args.overwrite_cache,
|
||||||
|
mode=Split.train,
|
||||||
|
)
|
||||||
|
if training_args.do_train
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
eval_dataset = (
|
||||||
|
TFMultipleChoiceDataset(
|
||||||
|
data_dir=data_args.data_dir,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
task=data_args.task_name,
|
||||||
|
max_seq_length=data_args.max_seq_length,
|
||||||
|
overwrite_cache=data_args.overwrite_cache,
|
||||||
|
mode=Split.dev,
|
||||||
|
)
|
||||||
|
if training_args.do_eval
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_metrics(p: EvalPrediction) -> Dict:
|
||||||
|
preds = np.argmax(p.predictions, axis=1)
|
||||||
|
return {"acc": simple_accuracy(preds, p.label_ids)}
|
||||||
|
|
||||||
|
# Initialize our Trainer
|
||||||
|
trainer = TFTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_dataset.get_dataset() if train_dataset else None,
|
||||||
|
eval_dataset=eval_dataset.get_dataset() if eval_dataset else None,
|
||||||
|
compute_metrics=compute_metrics,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
if training_args.do_train:
|
||||||
|
trainer.train()
|
||||||
|
trainer.save_model()
|
||||||
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
|
# Evaluation
|
||||||
|
results = {}
|
||||||
|
if training_args.do_eval:
|
||||||
|
logger.info("*** Evaluate ***")
|
||||||
|
|
||||||
|
result = trainer.evaluate()
|
||||||
|
|
||||||
|
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
|
||||||
|
with open(output_eval_file, "w") as writer:
|
||||||
|
logger.info("***** Eval results *****")
|
||||||
|
for key, value in result.items():
|
||||||
|
logger.info(" %s = %s", key, value)
|
||||||
|
writer.write("%s = %s\n" % (key, value))
|
||||||
|
|
||||||
|
results.update(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -25,11 +25,9 @@ from dataclasses import dataclass
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
|
||||||
import tqdm
|
import tqdm
|
||||||
from torch.utils.data.dataset import Dataset
|
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer, torch_distributed_zero_first
|
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -76,6 +74,11 @@ class Split(Enum):
|
|||||||
test = "test"
|
test = "test"
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
from transformers import torch_distributed_zero_first
|
||||||
|
|
||||||
class MultipleChoiceDataset(Dataset):
|
class MultipleChoiceDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
This will be superseded by a framework-agnostic approach
|
This will be superseded by a framework-agnostic approach
|
||||||
@@ -138,6 +141,95 @@ class MultipleChoiceDataset(Dataset):
|
|||||||
return self.features[i]
|
return self.features[i]
|
||||||
|
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
class TFMultipleChoiceDataset:
|
||||||
|
"""
|
||||||
|
This will be superseded by a framework-agnostic approach
|
||||||
|
soon.
|
||||||
|
"""
|
||||||
|
|
||||||
|
features: List[InputFeatures]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_dir: str,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
task: str,
|
||||||
|
max_seq_length: Optional[int] = 128,
|
||||||
|
overwrite_cache=False,
|
||||||
|
mode: Split = Split.train,
|
||||||
|
):
|
||||||
|
processor = processors[task]()
|
||||||
|
|
||||||
|
logger.info(f"Creating features from dataset file at {data_dir}")
|
||||||
|
label_list = processor.get_labels()
|
||||||
|
if mode == Split.dev:
|
||||||
|
examples = processor.get_dev_examples(data_dir)
|
||||||
|
elif mode == Split.test:
|
||||||
|
examples = processor.get_test_examples(data_dir)
|
||||||
|
else:
|
||||||
|
examples = processor.get_train_examples(data_dir)
|
||||||
|
logger.info("Training examples: %s", len(examples))
|
||||||
|
# TODO clean up all this to leverage built-in features of tokenizers
|
||||||
|
self.features = convert_examples_to_features(
|
||||||
|
examples,
|
||||||
|
label_list,
|
||||||
|
max_seq_length,
|
||||||
|
tokenizer,
|
||||||
|
pad_on_left=bool(tokenizer.padding_side == "left"),
|
||||||
|
pad_token=tokenizer.pad_token_id,
|
||||||
|
pad_token_segment_id=tokenizer.pad_token_type_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def gen():
|
||||||
|
for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
|
||||||
|
if ex_index % 10000 == 0:
|
||||||
|
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||||
|
|
||||||
|
yield (
|
||||||
|
{
|
||||||
|
"example_id": 0,
|
||||||
|
"input_ids": ex.input_ids,
|
||||||
|
"attention_mask": ex.attention_mask,
|
||||||
|
"token_type_ids": ex.token_type_ids,
|
||||||
|
},
|
||||||
|
ex.label,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dataset = tf.data.Dataset.from_generator(
|
||||||
|
gen,
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"example_id": tf.int32,
|
||||||
|
"input_ids": tf.int32,
|
||||||
|
"attention_mask": tf.int32,
|
||||||
|
"token_type_ids": tf.int32,
|
||||||
|
},
|
||||||
|
tf.int64,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"example_id": tf.TensorShape([]),
|
||||||
|
"input_ids": tf.TensorShape([None, None]),
|
||||||
|
"attention_mask": tf.TensorShape([None, None]),
|
||||||
|
"token_type_ids": tf.TensorShape([None, None]),
|
||||||
|
},
|
||||||
|
tf.TensorShape([]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_dataset(self):
|
||||||
|
return self.dataset
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.features)
|
||||||
|
|
||||||
|
def __getitem__(self, i) -> InputFeatures:
|
||||||
|
return self.features[i]
|
||||||
|
|
||||||
|
|
||||||
class DataProcessor:
|
class DataProcessor:
|
||||||
"""Base class for data converters for multiple choice data sets."""
|
"""Base class for data converters for multiple choice data sets."""
|
||||||
|
|
||||||
@@ -225,6 +317,52 @@ class RaceProcessor(DataProcessor):
|
|||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
class SynonymProcessor(DataProcessor):
|
||||||
|
"""Processor for the Synonym data set."""
|
||||||
|
|
||||||
|
def get_train_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
logger.info("LOOKING AT {} train".format(data_dir))
|
||||||
|
return self._create_examples(self._read_csv(os.path.join(data_dir, "mctrain.csv")), "train")
|
||||||
|
|
||||||
|
def get_dev_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||||
|
return self._create_examples(self._read_csv(os.path.join(data_dir, "mchp.csv")), "dev")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||||
|
|
||||||
|
return self._create_examples(self._read_csv(os.path.join(data_dir, "mctest.csv")), "test")
|
||||||
|
|
||||||
|
def get_labels(self):
|
||||||
|
"""See base class."""
|
||||||
|
return ["0", "1", "2", "3", "4"]
|
||||||
|
|
||||||
|
def _read_csv(self, input_file):
|
||||||
|
with open(input_file, "r", encoding="utf-8") as f:
|
||||||
|
return list(csv.reader(f))
|
||||||
|
|
||||||
|
def _create_examples(self, lines: List[List[str]], type: str):
|
||||||
|
"""Creates examples for the training and dev sets."""
|
||||||
|
|
||||||
|
examples = [
|
||||||
|
InputExample(
|
||||||
|
example_id=line[0],
|
||||||
|
question="", # in the swag dataset, the
|
||||||
|
# common beginning of each
|
||||||
|
# choice is stored in "sent2".
|
||||||
|
contexts=[line[1], line[1], line[1], line[1], line[1]],
|
||||||
|
endings=[line[2], line[3], line[4], line[5], line[6]],
|
||||||
|
label=line[7],
|
||||||
|
)
|
||||||
|
for line in lines # we skip the line with the column names
|
||||||
|
]
|
||||||
|
|
||||||
|
return examples
|
||||||
|
|
||||||
|
|
||||||
class SwagProcessor(DataProcessor):
|
class SwagProcessor(DataProcessor):
|
||||||
"""Processor for the SWAG data set."""
|
"""Processor for the SWAG data set."""
|
||||||
|
|
||||||
@@ -435,7 +573,5 @@ def convert_examples_to_features(
|
|||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
processors = {"race": RaceProcessor, "swag": SwagProcessor, "arc": ArcProcessor}
|
processors = {"race": RaceProcessor, "swag": SwagProcessor, "arc": ArcProcessor, "syn": SynonymProcessor}
|
||||||
|
MULTIPLE_CHOICE_TASKS_NUM_LABELS = {"race", 4, "swag", 4, "arc", 4, "syn", 5}
|
||||||
|
|
||||||
MULTIPLE_CHOICE_TASKS_NUM_LABELS = {"race", 4, "swag", 4, "arc", 4}
|
|
||||||
|
|||||||
@@ -359,6 +359,7 @@ if is_tf_available():
|
|||||||
from .modeling_tf_auto import (
|
from .modeling_tf_auto import (
|
||||||
TFAutoModel,
|
TFAutoModel,
|
||||||
TFAutoModelForPreTraining,
|
TFAutoModelForPreTraining,
|
||||||
|
TFAutoModelForMultipleChoice,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
TFAutoModelForQuestionAnswering,
|
TFAutoModelForQuestionAnswering,
|
||||||
TFAutoModelWithLMHead,
|
TFAutoModelWithLMHead,
|
||||||
@@ -493,6 +494,7 @@ if is_tf_available():
|
|||||||
TFAlbertModel,
|
TFAlbertModel,
|
||||||
TFAlbertForPreTraining,
|
TFAlbertForPreTraining,
|
||||||
TFAlbertForMaskedLM,
|
TFAlbertForMaskedLM,
|
||||||
|
TFAlbertForMultipleChoice,
|
||||||
TFAlbertForSequenceClassification,
|
TFAlbertForSequenceClassification,
|
||||||
TFAlbertForQuestionAnswering,
|
TFAlbertForQuestionAnswering,
|
||||||
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import logging
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from .configuration_albert import AlbertConfig
|
from .configuration_albert import AlbertConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
|
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
|
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
@@ -957,3 +957,127 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
|
|||||||
outputs = (start_logits, end_logits,) + outputs[2:]
|
outputs = (start_logits, end_logits,) + outputs[2:]
|
||||||
|
|
||||||
return outputs # start_logits, end_logits, (hidden_states), (attentions)
|
return outputs # start_logits, end_logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""Albert Model with a multiple choice classification head on top (a linear layer on top of
|
||||||
|
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
||||||
|
ALBERT_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
|
||||||
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|
||||||
|
self.albert = TFAlbertMainLayer(config, name="albert")
|
||||||
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.classifier = tf.keras.layers.Dense(
|
||||||
|
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dummy_inputs(self):
|
||||||
|
""" Dummy inputs to build the network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tf.Tensor with dummy inputs
|
||||||
|
"""
|
||||||
|
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
|
||||||
|
|
||||||
|
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
inputs,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Return:
|
||||||
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||||
|
classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`:
|
||||||
|
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
|
||||||
|
|
||||||
|
Classification scores (before SoftMax).
|
||||||
|
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
|
||||||
|
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||||
|
tuple of :obj:`tf.Tensor` (one for each layer) of shape
|
||||||
|
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from transformers import AlbertTokenizer, TFAlbertForMultipleChoice
|
||||||
|
|
||||||
|
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
||||||
|
model = TFAlbertForMultipleChoice.from_pretrained('albert-base-v2')
|
||||||
|
|
||||||
|
example1 = ["This is a context", "Is it a context? Yes"]
|
||||||
|
example2 = ["This is a context", "Is it a context? No"]
|
||||||
|
encoding = tokenizer.batch_encode_plus([example1, example2], return_tensors='tf', truncation_strategy="only_first", pad_to_max_length=True, max_length=128)
|
||||||
|
outputs = model(encoding["input_ids"][None, :])
|
||||||
|
logits = outputs[0]
|
||||||
|
|
||||||
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
input_ids = inputs[0]
|
||||||
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
|
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||||
|
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||||
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
|
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||||
|
assert len(inputs) <= 6, "Too many inputs."
|
||||||
|
elif isinstance(inputs, dict):
|
||||||
|
print("isdict(1)")
|
||||||
|
input_ids = inputs.get("input_ids")
|
||||||
|
print(input_ids)
|
||||||
|
|
||||||
|
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||||
|
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||||
|
position_ids = inputs.get("position_ids", position_ids)
|
||||||
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
|
assert len(inputs) <= 6, "Too many inputs."
|
||||||
|
else:
|
||||||
|
input_ids = inputs
|
||||||
|
|
||||||
|
if input_ids is not None:
|
||||||
|
num_choices = shape_list(input_ids)[1]
|
||||||
|
seq_length = shape_list(input_ids)[2]
|
||||||
|
else:
|
||||||
|
num_choices = shape_list(inputs_embeds)[1]
|
||||||
|
seq_length = shape_list(inputs_embeds)[2]
|
||||||
|
|
||||||
|
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||||
|
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||||
|
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||||
|
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||||
|
|
||||||
|
flat_inputs = [
|
||||||
|
flat_input_ids,
|
||||||
|
flat_attention_mask,
|
||||||
|
flat_token_type_ids,
|
||||||
|
flat_position_ids,
|
||||||
|
head_mask,
|
||||||
|
inputs_embeds,
|
||||||
|
]
|
||||||
|
|
||||||
|
outputs = self.albert(flat_inputs, training=training)
|
||||||
|
|
||||||
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
|
pooled_output = self.dropout(pooled_output, training=training)
|
||||||
|
logits = self.classifier(pooled_output)
|
||||||
|
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||||
|
|
||||||
|
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||||
|
|
||||||
|
return outputs # reshaped_logits, (hidden_states), (attentions)
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from .configuration_utils import PretrainedConfig
|
|||||||
from .modeling_tf_albert import (
|
from .modeling_tf_albert import (
|
||||||
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
TFAlbertForMaskedLM,
|
TFAlbertForMaskedLM,
|
||||||
|
TFAlbertForMultipleChoice,
|
||||||
TFAlbertForPreTraining,
|
TFAlbertForPreTraining,
|
||||||
TFAlbertForQuestionAnswering,
|
TFAlbertForQuestionAnswering,
|
||||||
TFAlbertForSequenceClassification,
|
TFAlbertForSequenceClassification,
|
||||||
@@ -44,6 +45,7 @@ from .modeling_tf_albert import (
|
|||||||
from .modeling_tf_bert import (
|
from .modeling_tf_bert import (
|
||||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
TFBertForMaskedLM,
|
TFBertForMaskedLM,
|
||||||
|
TFBertForMultipleChoice,
|
||||||
TFBertForPreTraining,
|
TFBertForPreTraining,
|
||||||
TFBertForQuestionAnswering,
|
TFBertForQuestionAnswering,
|
||||||
TFBertForSequenceClassification,
|
TFBertForSequenceClassification,
|
||||||
@@ -172,6 +174,10 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||||
|
[(BertConfig, TFBertForMultipleChoice), (AlbertConfig, TFAlbertForMultipleChoice)]
|
||||||
|
)
|
||||||
|
|
||||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
(DistilBertConfig, TFDistilBertForQuestionAnswering),
|
(DistilBertConfig, TFDistilBertForQuestionAnswering),
|
||||||
@@ -662,6 +668,153 @@ class TFAutoModelWithLMHead(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TFAutoModelForMultipleChoice:
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.TFAutoModelForMultipleChoice` is a generic model class
|
||||||
|
that will be instantiated as one of the multiple choice model classes of the library
|
||||||
|
when created with the `TFAutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)`
|
||||||
|
class method.
|
||||||
|
|
||||||
|
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||||
|
based on the `model_type` property of the config object, or when it's missing,
|
||||||
|
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||||
|
|
||||||
|
The model class to instantiate is selected as the first pattern matching
|
||||||
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
|
- contains `albert`: TFAlbertForMultipleChoice (Albert model)
|
||||||
|
- contains `bert`: TFBertForMultipleChoice (Bert model)
|
||||||
|
|
||||||
|
This class cannot be instantiated using `__init__()` (throws an error).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"TFAutoModelForMultipleChoice is designed to be instantiated "
|
||||||
|
"using the `TFAutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
|
"`TFAutoModelForMultipleChoice.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config):
|
||||||
|
r""" Instantiates one of the base model classes of the library
|
||||||
|
from a configuration.
|
||||||
|
|
||||||
|
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||||
|
The model class to instantiate is selected based on the configuration class:
|
||||||
|
- isInstance of `albert` configuration class: AlbertModel (Albert model)
|
||||||
|
- isInstance of `bert` configuration class: BertModel (Bert model)
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
|
model = AutoModelForMulitpleChoice.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
|
"""
|
||||||
|
for config_class, model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
|
||||||
|
if isinstance(config, config_class):
|
||||||
|
return model_class(config)
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
r""" Instantiates one of the multiple choice model classes of the library
|
||||||
|
from a pre-trained model configuration.
|
||||||
|
|
||||||
|
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||||
|
based on the `model_type` property of the config object, or when it's missing,
|
||||||
|
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||||
|
|
||||||
|
The model class to instantiate is selected as the first pattern matching
|
||||||
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
|
- contains `albert`: TFRobertaForMultiple (Albert model)
|
||||||
|
- contains `bert`: TFBertForMultipleChoice (Bert model)
|
||||||
|
|
||||||
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||||
|
To train the model, you should first set it back in training mode with `model.train()`
|
||||||
|
|
||||||
|
Params:
|
||||||
|
pretrained_model_name_or_path: either:
|
||||||
|
|
||||||
|
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||||
|
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
||||||
|
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||||
|
- a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument.
|
||||||
|
|
||||||
|
from_pt: (`Optional`) Boolean
|
||||||
|
Set to True if the Checkpoint is a PyTorch checkpoint.
|
||||||
|
|
||||||
|
model_args: (`optional`) Sequence of positional arguments:
|
||||||
|
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||||
|
|
||||||
|
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||||
|
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
||||||
|
|
||||||
|
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||||
|
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
|
||||||
|
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
|
||||||
|
|
||||||
|
state_dict: (`optional`) dict:
|
||||||
|
an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
|
||||||
|
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
|
||||||
|
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||||
|
|
||||||
|
cache_dir: (`optional`) string:
|
||||||
|
Path to a directory in which a downloaded pre-trained model
|
||||||
|
configuration should be cached if the standard cache should not be used.
|
||||||
|
|
||||||
|
force_download: (`optional`) boolean, default False:
|
||||||
|
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||||
|
|
||||||
|
resume_download: (`optional`) boolean, default False:
|
||||||
|
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||||
|
|
||||||
|
proxies: (`optional`) dict, default None:
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||||
|
The proxies are used on each request.
|
||||||
|
|
||||||
|
output_loading_info: (`optional`) boolean:
|
||||||
|
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
||||||
|
|
||||||
|
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||||
|
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
|
||||||
|
|
||||||
|
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
|
||||||
|
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
model = TFAutoModelFormultipleChoice.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
|
||||||
|
model = TFAutoModelFormultipleChoice.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
|
model = TFAutoModelFormultipleChoice.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
|
||||||
|
assert model.config.output_attention == True
|
||||||
|
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||||
|
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
|
||||||
|
model = TFAutoModelFormultipleChoice.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
|
||||||
|
|
||||||
|
"""
|
||||||
|
config = kwargs.pop("config", None)
|
||||||
|
if not isinstance(config, PretrainedConfig):
|
||||||
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
|
for config_class, model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
|
||||||
|
if isinstance(config, config_class):
|
||||||
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TFAutoModelForSequenceClassification(object):
|
class TFAutoModelForSequenceClassification(object):
|
||||||
r"""
|
r"""
|
||||||
:class:`~transformers.TFAutoModelForSequenceClassification` is a generic model class
|
:class:`~transformers.TFAutoModelForSequenceClassification` is a generic model class
|
||||||
|
|||||||
@@ -125,7 +125,9 @@ class TFTrainer:
|
|||||||
in the Tensorflow documentation and those contained in the transformers library.
|
in the Tensorflow documentation and those contained in the transformers library.
|
||||||
"""
|
"""
|
||||||
if self.args.optimizer_name == "adamw":
|
if self.args.optimizer_name == "adamw":
|
||||||
self.optimizer = create_optimizer(self.args.learning_rate, self.train_steps, self.args.warmup_steps)
|
self.optimizer = create_optimizer(
|
||||||
|
self.args.learning_rate, self.train_steps, self.args.warmup_steps, self.args.end_lr
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
self.optimizer = tf.keras.optimizers.get(
|
self.optimizer = tf.keras.optimizers.get(
|
||||||
@@ -139,6 +141,7 @@ class TFTrainer:
|
|||||||
self.optimizer = tf.keras.optimizers.get(
|
self.optimizer = tf.keras.optimizers.get(
|
||||||
{"class_name": self.args.optimizer_name, "config": {"learning_rate": self.args.learning_rate}}
|
{"class_name": self.args.optimizer_name, "config": {"learning_rate": self.args.learning_rate}}
|
||||||
)
|
)
|
||||||
|
logger.info("Created an/a {} optimizer".format(self.optimizer))
|
||||||
|
|
||||||
def _create_checkpoint_manager(self, max_to_keep: int = 5, load_model: bool = True) -> None:
|
def _create_checkpoint_manager(self, max_to_keep: int = 5, load_model: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -149,6 +152,7 @@ class TFTrainer:
|
|||||||
load_model: if we want to start the training from the latest checkpoint.
|
load_model: if we want to start the training from the latest checkpoint.
|
||||||
"""
|
"""
|
||||||
ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
|
ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
|
||||||
|
|
||||||
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, PREFIX_CHECKPOINT_DIR, max_to_keep=max_to_keep)
|
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, PREFIX_CHECKPOINT_DIR, max_to_keep=max_to_keep)
|
||||||
|
|
||||||
if load_model:
|
if load_model:
|
||||||
@@ -425,5 +429,6 @@ class TFTrainer:
|
|||||||
|
|
||||||
path = os.path.join(self.args.output_dir, "saved_model")
|
path = os.path.join(self.args.output_dir, "saved_model")
|
||||||
|
|
||||||
|
logger.info("Saving model in {}".format(path))
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
self.model.save_pretrained(self.args.output_dir)
|
self.model.save_pretrained(self.args.output_dir)
|
||||||
|
|||||||
@@ -30,6 +30,12 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
"help": "Name of a Tensorflow loss. For the list see: https://www.tensorflow.org/api_docs/python/tf/keras/losses"
|
"help": "Name of a Tensorflow loss. For the list see: https://www.tensorflow.org/api_docs/python/tf/keras/losses"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
tpu_name: str = field(
|
||||||
|
default=None, metadata={"help": "Name of TPU"},
|
||||||
|
)
|
||||||
|
end_lr: float = field(
|
||||||
|
default=0, metadata={"help": "End learning rate for optimizer"},
|
||||||
|
)
|
||||||
eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
|
eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
|
||||||
debug: bool = field(
|
debug: bool = field(
|
||||||
default=False, metadata={"help": "Activate the trace to record computation graphs and profiling information"}
|
default=False, metadata={"help": "Activate the trace to record computation graphs and profiling information"}
|
||||||
@@ -45,6 +51,9 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
|
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
if self.tpu_name:
|
||||||
|
tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name)
|
||||||
|
else:
|
||||||
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
|
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
tpu = None
|
tpu = None
|
||||||
|
|||||||
Reference in New Issue
Block a user