Templates overhaul 1 (#8993)
This commit is contained in:
@@ -19,12 +19,18 @@ from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from cookiecutter.main import cookiecutter
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
try:
|
||||
from cookiecutter.main import cookiecutter
|
||||
|
||||
_has_cookiecutter = True
|
||||
except ImportError:
|
||||
_has_cookiecutter = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@@ -49,6 +55,11 @@ class AddNewModelCommand(BaseTransformersCLICommand):
|
||||
self._path = path
|
||||
|
||||
def run(self):
|
||||
if not _has_cookiecutter:
|
||||
raise ImportError(
|
||||
"Model creation dependencies are required to use the `add_new_model` command. Install them by running "
|
||||
"the folowing at the root of your `transformers` clone:\n\n\t$ pip install -e .[modelcreation]\n"
|
||||
)
|
||||
# Ensure that there is no other `cookiecutter-template-xxx` directory in the current working directory
|
||||
directories = [directory for directory in os.listdir() if "cookiecutter-template-" == directory[:22]]
|
||||
if len(directories) > 0:
|
||||
@@ -153,6 +164,11 @@ class AddNewModelCommand(BaseTransformersCLICommand):
|
||||
f"{model_dir}/tokenization_{lowercase_model_name}.py",
|
||||
)
|
||||
|
||||
shutil.move(
|
||||
f"{directory}/tokenization_fast_{lowercase_model_name}.py",
|
||||
f"{model_dir}/tokenization_{lowercase_model_name}_fast.py",
|
||||
)
|
||||
|
||||
from os import fdopen, remove
|
||||
from shutil import copymode, move
|
||||
from tempfile import mkstemp
|
||||
|
||||
Reference in New Issue
Block a user