Move source code inside a src subdirectory.
This prevents transformers from being importable simply because the CWD
is the root of the git repository, while not being importable from other
directories. That led to inconsistent behavior, especially in examples.
Once you fetch this commit, in your dev environment, you must run:
$ pip uninstall transformers
$ pip install -e .
This commit is contained in:
13
src/transformers/commands/__init__.py
Normal file
13
src/transformers/commands/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from argparse import ArgumentParser
|
||||
|
||||
|
||||
class BaseTransformersCLICommand(ABC):
    """
    Abstract interface of a command-line sub-command.

    Concrete commands provide a static ``register_subcommand`` hook that declares
    their arguments, and a ``run`` method that performs the work.
    """

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Declare the sub-command and its arguments.

        :param parser: Parser object the sub-command registers itself on
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def run(self):
        """
        Execute the command.

        :return:
        """
        raise NotImplementedError
|
||||
144
src/transformers/commands/convert.py
Normal file
144
src/transformers/commands/convert.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from logging import getLogger
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
|
||||
def convert_command_factory(args: Namespace):
    """
    Factory function building the command that converts an original checkpoint into a PyTorch checkpoint.

    :param args: Parsed command line arguments of the ``convert`` sub-command
    :return: ConvertCommand
    """
    return ConvertCommand(
        model_type=args.model_type,
        tf_checkpoint=args.tf_checkpoint,
        pytorch_dump_output=args.pytorch_dump_output,
        config=args.config,
        finetuning_task_name=args.finetuning_task_name,
    )
|
||||
|
||||
|
||||
class ConvertCommand(BaseTransformersCLICommand):
    """
    CLI command converting an original author checkpoint into a Transformers PyTorch checkpoint.
    """

    # Single copy of the message raised when a converter module cannot be imported.
    # The wording assumes the failure is caused by a missing TensorFlow installation.
    _TF_IMPORT_ERROR_MESSAGE = (
        "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
        "In that case, it requires TensorFlow to be installed. Please see "
        "https://www.tensorflow.org/install/ for installation instructions."
    )

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli

        :param parser: Root parser to register command-specific arguments
        :return:
        """
        convert_parser = parser.add_parser(
            "convert",
            help="CLI tool to run convert model from original "
            "author checkpoints to Transformers PyTorch checkpoints.",
        )
        convert_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
        convert_parser.add_argument(
            "--tf_checkpoint", type=str, required=True, help="TensorFlow checkpoint path or folder."
        )
        convert_parser.add_argument(
            "--pytorch_dump_output", type=str, required=True, help="Path to the PyTorch saved model output."
        )
        convert_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.")
        convert_parser.add_argument(
            "--finetuning_task_name",
            type=str,
            default=None,
            help="Optional fine-tuning task name if the TF model was a finetuned model.",
        )
        convert_parser.set_defaults(func=convert_command_factory)

    def __init__(
        self,
        model_type: str,
        tf_checkpoint: str,
        pytorch_dump_output: str,
        config: str,
        finetuning_task_name: str,
        *args
    ):
        """
        Store the conversion parameters.

        :param model_type: One of bert, gpt, gpt2, transfo_xl, xlnet, xlm
        :param tf_checkpoint: Path of the checkpoint (or folder) to convert
        :param pytorch_dump_output: Path the converted PyTorch model is written to
        :param config: Configuration file path or folder
        :param finetuning_task_name: Optional fine-tuning task name, only forwarded for xlnet
        :param args: Extra positional arguments, accepted and ignored
        """
        self._logger = getLogger("transformers-cli/converting")

        self._logger.info("Loading model {}".format(model_type))
        self._model_type = model_type
        self._tf_checkpoint = tf_checkpoint
        self._pytorch_dump_output = pytorch_dump_output
        self._config = config
        self._finetuning_task_name = finetuning_task_name

    def run(self):
        """
        Run the converter matching ``model_type``.

        Converter modules are imported lazily, inside the branch that needs them.

        :raises ImportError: If the converter module of the selected model cannot be imported
        :raises ValueError: If ``model_type`` is not one of the supported values
        """
        if self._model_type == "bert":
            try:
                from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
                    convert_tf_checkpoint_to_pytorch,
                )
            except ImportError as err:
                # Chain explicitly so the underlying import failure stays visible.
                raise ImportError(self._TF_IMPORT_ERROR_MESSAGE) from err

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "gpt":
            from transformers.convert_openai_original_tf_checkpoint_to_pytorch import (
                convert_openai_checkpoint_to_pytorch,
            )

            convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "transfo_xl":
            try:
                from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
                    convert_transfo_xl_checkpoint_to_pytorch,
                )
            except ImportError as err:
                raise ImportError(self._TF_IMPORT_ERROR_MESSAGE) from err

            # The same CLI argument carries either a checkpoint or a dataset file:
            # it is treated as a checkpoint when its path contains "ckpt".
            if "ckpt" in self._tf_checkpoint.lower():
                tf_checkpoint = self._tf_checkpoint
                tf_dataset_file = ""
            else:
                tf_dataset_file = self._tf_checkpoint
                tf_checkpoint = ""
            convert_transfo_xl_checkpoint_to_pytorch(
                tf_checkpoint, self._config, self._pytorch_dump_output, tf_dataset_file
            )
        elif self._model_type == "gpt2":
            try:
                from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import (
                    convert_gpt2_checkpoint_to_pytorch,
                )
            except ImportError as err:
                raise ImportError(self._TF_IMPORT_ERROR_MESSAGE) from err

            convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "xlnet":
            try:
                from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import (
                    convert_xlnet_checkpoint_to_pytorch,
                )
            except ImportError as err:
                raise ImportError(self._TF_IMPORT_ERROR_MESSAGE) from err

            convert_xlnet_checkpoint_to_pytorch(
                self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
            )
        elif self._model_type == "xlm":
            from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
                convert_xlm_checkpoint_to_pytorch,
            )

            # The xlm converter is called without the config argument.
            convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
        else:
            raise ValueError("--model_type should be selected in the list [bert, gpt, gpt2, transfo_xl, xlnet, xlm]")
|
||||
32
src/transformers/commands/download.py
Normal file
32
src/transformers/commands/download.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
|
||||
def download_command_factory(args):
    """
    Factory function building the ``download`` command from parsed command line arguments.

    :param args: Namespace providing ``model``, ``cache_dir`` and ``force``
    :return: DownloadCommand
    """
    model, cache_dir, force = args.model, args.cache_dir, args.force
    return DownloadCommand(model, cache_dir, force)
|
||||
|
||||
|
||||
class DownloadCommand(BaseTransformersCLICommand):
    """
    CLI command loading a model and its tokenizer through ``AutoModel`` / ``AutoTokenizer``.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register the ``download`` sub-command and its arguments.

        :param parser: Parser object the sub-command is added to
        :return:
        """
        sub_parser = parser.add_parser("download")
        sub_parser.add_argument(
            "--cache-dir", type=str, default=None, help="Path to location to store the models"
        )
        sub_parser.add_argument(
            "--force", action="store_true", help="Force the model to be download even if already in cache-dir"
        )
        sub_parser.add_argument("model", type=str, help="Name of the model to download")
        sub_parser.set_defaults(func=download_command_factory)

    def __init__(self, model: str, cache: str, force: bool):
        """
        :param model: Name of the model to download
        :param cache: Value forwarded as ``cache_dir``
        :param force: Value forwarded as ``force_download``
        """
        self._model, self._cache, self._force = model, cache, force

    def run(self):
        """
        Call ``from_pretrained`` on the model class first, then on the tokenizer class.
        """
        # Imported here rather than at module level, as in the original layout.
        from transformers import AutoModel, AutoTokenizer

        options = {"cache_dir": self._cache, "force_download": self._force}
        for auto_class in (AutoModel, AutoTokenizer):
            auto_class.from_pretrained(self._model, **options)
|
||||
96
src/transformers/commands/run.py
Normal file
96
src/transformers/commands/run.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import logging
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def try_infer_format_from_ext(path: str):
    """
    Guess the pipeline data format from the end of ``path``.

    :param path: Input file path; a falsy value (``None`` or ``""``) selects the ``"pipe"`` format
    :return: ``"pipe"`` for a falsy path, otherwise the first entry of
        ``PipelineDataFormat.SUPPORTED_FORMATS`` that ``path`` ends with
    :raises ValueError: If ``path`` ends with none of the supported formats
    """
    if not path:
        return "pipe"

    for ext in PipelineDataFormat.SUPPORTED_FORMATS:
        if path.endswith(ext):
            return ext

    # ValueError instead of a bare Exception: callers catching Exception still work,
    # and callers can now catch this failure specifically.
    raise ValueError(
        "Unable to determine file format from file extension {}. "
        "Please provide the format through --format {}".format(path, PipelineDataFormat.SUPPORTED_FORMATS)
    )
|
||||
|
||||
|
||||
def run_command_factory(args):
    """
    Factory function building the ``run`` command from parsed command line arguments.

    :param args: Parsed command line arguments of the ``run`` sub-command
    :return: RunCommand
    """
    nlp = pipeline(
        task=args.task,
        model=args.model or None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )

    # "infer" means: derive the format from the input path.
    if args.format == "infer":
        data_format = try_infer_format_from_ext(args.input)
    else:
        data_format = args.format

    reader = PipelineDataFormat.from_str(
        format=data_format,
        output_path=args.output,
        input_path=args.input,
        column=args.column or nlp.default_input_names,
        overwrite=args.overwrite,
    )
    return RunCommand(nlp, reader)
|
||||
|
||||
|
||||
class RunCommand(BaseTransformersCLICommand):
    """
    CLI command feeding every entry of a data reader to a pipeline and saving the outputs.
    """

    def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
        """
        :param nlp: Pipeline called on each entry
        :param reader: Iterable data source, also used to save the outputs
        """
        self._nlp = nlp
        self._reader = reader

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register the ``run`` sub-command and its arguments.

        :param parser: Parser object the sub-command is added to
        :return:
        """
        sub_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
        sub_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run")
        sub_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
        sub_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
        sub_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
        sub_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
        sub_parser.add_argument(
            "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
        )
        sub_parser.add_argument(
            "--column",
            type=str,
            help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
        )
        sub_parser.add_argument(
            "--format",
            type=str,
            default="infer",
            choices=PipelineDataFormat.SUPPORTED_FORMATS,
            help="Input format to read from",
        )
        sub_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        sub_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
        sub_parser.set_defaults(func=run_command_factory)

    def run(self):
        """
        Run the pipeline over every reader entry, then save the collected outputs.
        """
        pipe, reader = self._nlp, self._reader
        results = []

        for entry in reader:
            # Multi-column entries are expanded into keyword arguments.
            if reader.is_multi_columns:
                output = pipe(**entry)
            else:
                output = pipe(entry)

            # A dict is one result; anything else is treated as a sequence of results.
            if isinstance(output, dict):
                results.append(output)
            else:
                results.extend(output)

        # Saving data
        if pipe.binary_output:
            binary_path = reader.save_binary(results)
            logger.warning("Current pipeline requires output to be in binary format, saving at {}".format(binary_path))
        else:
            reader.save(results)
|
||||
185
src/transformers/commands/serving.py
Normal file
185
src/transformers/commands/serving.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import logging
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from transformers import Pipeline
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||
|
||||
|
||||
# Optional serving dependencies: record whether they could be imported.
try:
    from uvicorn import run
    from fastapi import FastAPI, HTTPException, Body
    from pydantic import BaseModel
except (ImportError, AttributeError):
    # Stand-ins for the two names used at definition time elsewhere in this module
    # (BaseModel as a base class, Body in default argument values).
    BaseModel = object

    def Body(*x, **y):
        return None

    _serve_dependancies_installed = False
else:
    _serve_dependancies_installed = True
|
||||
|
||||
|
||||
logger = logging.getLogger("transformers-cli/serving")
|
||||
|
||||
|
||||
def serve_command_factory(args: Namespace):
    """
    Factory function building the serving command from parsed command line arguments.

    :param args: Parsed command line arguments of the ``serve`` sub-command
    :return: ServeCommand
    """
    pipeline_options = dict(
        task=args.task,
        model=args.model or None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    nlp = pipeline(**pipeline_options)
    return ServeCommand(nlp, args.host, args.port)
|
||||
|
||||
|
||||
class ServeModelInfoResult(BaseModel):
    """
    Response model carrying model information
    """

    # Dictionary of model information
    infos: dict
|
||||
|
||||
|
||||
class ServeTokenizeResult(BaseModel):
    """
    Tokenize result model
    """

    # Tokens produced by the tokenizer
    tokens: List[str]
    # Explicit default: the result is also built from tokens alone, without ids.
    # Optional[...] does not by itself make a field omittable in every pydantic version.
    tokens_ids: Optional[List[int]] = None
|
||||
|
||||
|
||||
class ServeDeTokenizeResult(BaseModel):
    """
    Response model carrying detokenized text
    """

    # Decoded text
    text: str
|
||||
|
||||
|
||||
class ServeForwardResult(BaseModel):
    """
    Response model carrying the output of a forward call
    """

    # Output of the forward call, left untyped
    output: Any
|
||||
|
||||
|
||||
class ServeCommand(BaseTransformersCLICommand):
    """
    CLI command exposing a pipeline through a FastAPI application run with uvicorn.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli

        :param parser: Root parser to register command-specific arguments
        :return:
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        serve_parser.add_argument(
            "--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on"
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        serve_parser.set_defaults(func=serve_command_factory)

    def __init__(self, pipeline: Pipeline, host: str, port: int):
        """
        Build the FastAPI application and register its routes.

        :param pipeline: Pipeline used to answer the requests
        :param host: Interface the server will listen on
        :param port: Port the server will listen to
        :raises ImportError: If the serving dependencies could not be imported
        """
        self._pipeline = pipeline

        self._host = host
        self._port = port
        if not _serve_dependancies_installed:
            # Fixed wording: the server package is uvicorn, and the sentences are separated.
            raise ImportError(
                "Using serve command requires FastAPI and uvicorn. "
                "Please install transformers with [serving]: pip install transformers[serving]. "
                "Or install FastAPI and uvicorn separately."
            )

        logger.info("Serving model over {}:{}".format(host, port))
        self._app = FastAPI()

        # Register routes
        self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
        self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
        self._app.add_api_route(
            "/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
        )
        self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])

    def run(self):
        # ``run`` here resolves to the module-level uvicorn ``run``, not to this method.
        run(self._app, host=self._host, port=self._port)

    def model_info(self):
        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))

    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and eventually returns corresponding tokens id:
        - **text_input**: String to tokenize
        - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer mapping.
        """
        try:
            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)

            if return_ids:
                tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
                return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
            else:
                return ServeTokenizeResult(tokens=tokens_txt)

        except Exception as e:
            # Endpoint boundary: any failure is reported to the client as an HTTP 500.
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided tokens ids to readable text:
        - **tokens_ids**: List of tokens ids
        - **skip_special_tokens**: Flag indicating to not try to decode special tokens
        - **cleanup_tokenization_spaces**: Flag indicating to remove all leading/trailing spaces and intermediate ones.
        """
        try:
            decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
            # Only ``text`` is passed: the result model declares no other field.
            return ServeDeTokenizeResult(text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
        """
        **inputs**:
        **attention_mask**:
        **tokens_type_ids**:
        """

        # Empty input, or no input at all (the default is None): answer with an empty
        # output instead of failing on len(None).
        if not inputs:
            return ServeForwardResult(output=[])

        try:
            # Forward through the model
            output = self._pipeline(inputs)
            return ServeForwardResult(output=output)
        except Exception as e:
            raise HTTPException(500, {"error": str(e)})
|
||||
144
src/transformers/commands/train.py
Normal file
144
src/transformers/commands/train.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import os
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from logging import getLogger
|
||||
|
||||
from transformers import SingleSentenceClassificationProcessor as Processor
|
||||
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
|
||||
# CLI training needs at least one of the two frameworks.
if not (is_tf_available() or is_torch_available()):
    raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")

# TF training parameters
USE_XLA = False
USE_AMP = False
|
||||
|
||||
|
||||
def train_command_factory(args: Namespace):
    """
    Factory function used to instantiate the training command from provided command line arguments.

    :param args: Parsed command line arguments of the ``train`` sub-command
    :return: TrainCommand
    """
    command = TrainCommand(args)
    return command
|
||||
|
||||
|
||||
class TrainCommand(BaseTransformersCLICommand):
    """
    CLI command training a pipeline on a csv dataset and saving the result.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli

        :param parser: Root parser to register command-specific arguments
        :return:
        """
        train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")

        train_parser.add_argument(
            "--train_data",
            type=str,
            required=True,
            help="path to train (and optionally evaluation) dataset as a csv with "
            "tab separated labels and sentences.",
        )
        train_parser.add_argument(
            "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
        )
        train_parser.add_argument(
            "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
        )
        train_parser.add_argument(
            "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
        )
        train_parser.add_argument(
            "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
        )

        train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
        train_parser.add_argument(
            "--validation_split",
            type=float,
            default=0.1,
            help="if validation dataset is not provided, fraction of train dataset to use as validation dataset.",
        )

        train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")

        train_parser.add_argument(
            "--task", type=str, default="text_classification", help="Task to train the model on."
        )
        train_parser.add_argument(
            "--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model."
        )
        train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
        train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
        train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
        train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
        train_parser.set_defaults(func=train_command_factory)

    def __init__(self, args: Namespace):
        """
        Create the output folder, load the pipeline and load the dataset(s).

        :param args: Parsed command line arguments of the ``train`` sub-command
        :raises NotImplementedError: For the token_classification and question_answering tasks
        :raises ValueError: If ``args.task`` is not a known task
        """
        self.logger = getLogger("transformers-cli/training")

        # TensorFlow takes precedence when both frameworks are available.
        self.framework = "tf" if is_tf_available() else "torch"

        # makedirs(exist_ok=True) raises when the path exists and is not a directory,
        # so no separate assert on the result is needed (asserts vanish under -O anyway).
        os.makedirs(args.output, exist_ok=True)
        self.output = args.output

        self.column_label = args.column_label
        self.column_text = args.column_text
        self.column_id = args.column_id

        self.logger.info("Loading {} pipeline for {}".format(args.task, args.model))
        if args.task == "text_classification":
            self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
        elif args.task == "token_classification":
            raise NotImplementedError
        elif args.task == "question_answering":
            raise NotImplementedError
        else:
            # Fail early: otherwise self.pipeline would stay unset and training would
            # break later with an unrelated AttributeError.
            raise ValueError("Unsupported task: {}".format(args.task))

        self.logger.info("Loading dataset from {}".format(args.train_data))
        self.train_dataset = Processor.create_from_csv(
            args.train_data,
            column_label=args.column_label,
            column_text=args.column_text,
            column_id=args.column_id,
            skip_first_row=args.skip_first_row,
        )
        self.valid_dataset = None
        if args.validation_data:
            self.logger.info("Loading validation dataset from {}".format(args.validation_data))
            self.valid_dataset = Processor.create_from_csv(
                args.validation_data,
                column_label=args.column_label,
                column_text=args.column_text,
                column_id=args.column_id,
                skip_first_row=args.skip_first_row,
            )

        self.validation_split = args.validation_split
        self.train_batch_size = args.train_batch_size
        self.valid_batch_size = args.valid_batch_size
        self.learning_rate = args.learning_rate
        self.adam_epsilon = args.adam_epsilon

    def run(self):
        """
        Dispatch to the training routine of the selected framework.
        """
        if self.framework == "tf":
            return self.run_tf()
        return self.run_torch()

    def run_torch(self):
        """
        PyTorch training, not implemented.
        """
        raise NotImplementedError

    def run_tf(self):
        """
        Fit the pipeline on the training dataset, then save it to the output folder.
        """
        self.pipeline.fit(
            self.train_dataset,
            validation_data=self.valid_dataset,
            validation_split=self.validation_split,
            learning_rate=self.learning_rate,
            adam_epsilon=self.adam_epsilon,
            train_batch_size=self.train_batch_size,
            valid_batch_size=self.valid_batch_size,
        )

        # Save trained pipeline
        self.pipeline.save_pretrained(self.output)
|
||||
174
src/transformers/commands/user.py
Normal file
174
src/transformers/commands/user.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from getpass import getpass
|
||||
from typing import List, Union
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.hf_api import HfApi, HfFolder
|
||||
|
||||
|
||||
class UserCommands(BaseTransformersCLICommand):
    """
    Registers the user account sub-commands: login, whoami, logout, ls and upload.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register every user sub-command on ``parser``.

        :param parser: Parser object the sub-commands are added to
        :return:
        """
        # Sub-commands without arguments, in registration order.
        simple_commands = (
            ("login", lambda args: LoginCommand(args)),
            ("whoami", lambda args: WhoamiCommand(args)),
            ("logout", lambda args: LogoutCommand(args)),
            ("ls", lambda args: ListObjsCommand(args)),
        )
        for name, factory in simple_commands:
            parser.add_parser(name).set_defaults(func=factory)

        # upload
        upload_parser = parser.add_parser("upload")
        upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
        upload_parser.add_argument(
            "--filename", type=str, default=None, help="Optional: override individual object filename on S3."
        )
        upload_parser.set_defaults(func=lambda args: UploadCommand(args))
|
||||
|
||||
|
||||
class ANSI:
    """
    Small helper around ANSI escape codes, see en.wikipedia.org/wiki/ANSI_escape_code
    """

    _bold = "\u001b[1m"
    _reset = "\u001b[0m"

    @classmethod
    def bold(cls, s):
        """
        Wrap ``s`` between the bold and reset escape sequences.

        :param s: Value to render
        :return: The formatted string
        """
        return f"{cls._bold}{s}{cls._reset}"
|
||||
|
||||
|
||||
class BaseUserCommand:
    """
    Common base of the user commands: keeps the parsed arguments and an ``HfApi`` client.
    """

    def __init__(self, args):
        """
        :param args: Parsed command line arguments, stored as ``self.args``
        """
        self.args = args
        api_client = HfApi()
        self._api = api_client
|
||||
|
||||
|
||||
class LoginCommand(BaseUserCommand):
    """
    CLI command asking for credentials, logging in and saving the returned token.
    """

    def run(self):
        """
        Prompt for username and password, log in and store the token.

        :raises SystemExit: With status 1 when the login request fails with an HTTP error
        """
        # NOTE(review): the banner's horizontal spacing looks collapsed in this copy of
        # the file - confirm against the intended ASCII art.
        print(
            """
        _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
        _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
        _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
        _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
        _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|

        """
        )
        username = input("Username: ")
        password = getpass()
        try:
            token = self._api.login(username, password)
        except HTTPError as e:
            # probably invalid credentials, display error message.
            print(e)
            # SystemExit instead of exit(): the exit() helper is injected by the site
            # module for interactive use and is not guaranteed to exist.
            raise SystemExit(1)
        HfFolder.save_token(token)
        print("Login successful")
        print("Your token:", token, "\n")
        print("Your token has been saved to", HfFolder.path_token)
|
||||
|
||||
|
||||
class WhoamiCommand(BaseUserCommand):
    """
    CLI command printing the user the saved token belongs to.
    """

    def run(self):
        """
        Print the result of ``whoami`` for the saved token.

        :raises SystemExit: With the default (success) status when no token is saved
        """
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            # SystemExit instead of exit(): the exit() helper is injected by the site
            # module for interactive use and is not guaranteed to exist. Same exit status.
            raise SystemExit()
        try:
            user = self._api.whoami(token)
            print(user)
        except HTTPError as e:
            # HTTP errors are printed; the command then returns normally.
            print(e)
|
||||
|
||||
|
||||
class LogoutCommand(BaseUserCommand):
    """
    CLI command deleting the saved token and logging out through the API.
    """

    def run(self):
        """
        Delete the saved token, then call the API logout with it.

        :raises SystemExit: With the default (success) status when no token is saved
        """
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            # SystemExit instead of exit(): the exit() helper is injected by the site
            # module for interactive use and is not guaranteed to exist. Same exit status.
            raise SystemExit()
        # The saved token is deleted before the API call.
        HfFolder.delete_token()
        self._api.logout(token)
        print("Successfully logged out.")
|
||||
|
||||
|
||||
class ListObjsCommand(BaseUserCommand):
    """
    CLI command printing the objects returned by the API for the saved token as a table.
    """

    def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
        """
        Render ``rows`` under ``headers``, each column padded to its widest cell.

        Inspired by:
        stackoverflow.com/a/8356620/593036
        stackoverflow.com/questions/9535954/printing-lists-as-tabular-data

        :param rows: Table rows, one list of cells per row
        :param headers: Column titles
        :return: The table as one string: header line, dashes line, then one line per row
        """
        # Transpose rows plus headers to get, per column, the longest cell.
        col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
        row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
        lines = [
            row_format.format(*headers),
            row_format.format(*["-" * w for w in col_widths]),
        ]
        lines.extend(row_format.format(*row) for row in rows)
        return "\n".join(lines)

    def run(self):
        """
        Fetch the objects and print them as a table.

        :raises SystemExit: With status 1 when no token is saved or the request fails with
            an HTTP error, with the default (success) status when there is nothing to list
        """
        # SystemExit is raised directly instead of calling exit(): the exit() helper is
        # injected by the site module for interactive use and is not guaranteed to exist.
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            raise SystemExit(1)
        try:
            objs = self._api.list_objs(token)
        except HTTPError as e:
            print(e)
            raise SystemExit(1)
        if len(objs) == 0:
            print("No shared file yet")
            raise SystemExit()
        rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
        print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
|
||||
|
||||
|
||||
class UploadCommand(BaseUserCommand):
    """
    CLI command uploading a single file, or every file of a folder, after confirmation.
    """

    def walk_dir(self, rel_path, root=None):
        """
        Recursively list all files in a folder.

        :param rel_path: Folder to scan
        :param root: Optional folder the returned filenames are made relative to; when
            omitted the filenames are the scanned paths themselves
        :return: List of ``(filepath, filename)`` tuples, the files of a folder first,
            then those of its sub-folders
        """
        entries: List[os.DirEntry] = list(os.scandir(rel_path))
        files = [
            (
                os.path.join(os.getcwd(), f.path),  # filepath
                f.path if root is None else os.path.relpath(f.path, root),  # filename
            )
            for f in entries
            if f.is_file()
        ]
        for f in entries:
            if f.is_dir():
                files += self.walk_dir(f.path, root)
        return files

    def run(self):
        """
        Collect the files to upload, ask for confirmation, then upload them one by one.

        :raises SystemExit: With status 1 when no token is saved, with the default
            (success) status when the user declines
        :raises ValueError: If a filename override is given for a folder, or if the path
            is neither a file nor a folder
        """
        # SystemExit is raised directly instead of calling exit(): the exit() helper is
        # injected by the site module for interactive use and is not guaranteed to exist.
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            raise SystemExit(1)
        local_path = os.path.abspath(self.args.path)
        if os.path.isdir(local_path):
            if self.args.filename is not None:
                raise ValueError("Cannot specify a filename override when uploading a folder.")
            # Scan the folder itself rather than its basename resolved against the
            # current directory, which only worked for a direct child of the CWD.
            # Filenames stay relative to the folder's parent, i.e. "<folder>/<file>".
            files = self.walk_dir(local_path, root=os.path.dirname(local_path))
        elif os.path.isfile(local_path):
            filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path)
            files = [(local_path, filename)]
        else:
            raise ValueError("Not a valid file or directory: {}".format(local_path))

        for filepath, filename in files:
            print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))

        # An empty answer counts as yes.
        choice = input("Proceed? [Y/n] ").lower()
        if not (choice == "" or choice == "y" or choice == "yes"):
            print("Abort")
            raise SystemExit()
        print(ANSI.bold("Uploading... This might take a while if files are large"))
        for filepath, filename in files:
            access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath)
            print("Your file now lives at:")
            print(access_url)
|
||||
Reference in New Issue
Block a user