Fix model templates (#8595)
* First fixes * Fix imports and add init * Fix typo * Move init to final dest * Fix tokenization import * More fixes * Styling
This commit is contained in:
@@ -47,7 +47,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
|
||||
path_to_transformer_root = (
|
||||
Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent
|
||||
)
|
||||
path_to_cookiecutter = path_to_transformer_root / "templates" / "cookiecutter"
|
||||
path_to_cookiecutter = path_to_transformer_root / "templates" / "adding_a_new_model"
|
||||
|
||||
# Execute cookiecutter
|
||||
if not self._testing:
|
||||
@@ -75,9 +75,16 @@ class AddNewModelCommand(BaseTransformersCLICommand):
|
||||
output_pytorch = "PyTorch" in pytorch_or_tensorflow
|
||||
output_tensorflow = "TensorFlow" in pytorch_or_tensorflow
|
||||
|
||||
model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}"
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
shutil.move(
|
||||
f"{directory}/__init__.py",
|
||||
f"{model_dir}/__init__.py",
|
||||
)
|
||||
shutil.move(
|
||||
f"{directory}/configuration_{lowercase_model_name}.py",
|
||||
f"{path_to_transformer_root}/src/transformers/configuration_{lowercase_model_name}.py",
|
||||
f"{model_dir}/configuration_{lowercase_model_name}.py",
|
||||
)
|
||||
|
||||
def remove_copy_lines(path):
|
||||
@@ -94,7 +101,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
|
||||
|
||||
shutil.move(
|
||||
f"{directory}/modeling_{lowercase_model_name}.py",
|
||||
f"{path_to_transformer_root}/src/transformers/modeling_{lowercase_model_name}.py",
|
||||
f"{model_dir}/modeling_{lowercase_model_name}.py",
|
||||
)
|
||||
|
||||
shutil.move(
|
||||
@@ -111,7 +118,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
|
||||
|
||||
shutil.move(
|
||||
f"{directory}/modeling_tf_{lowercase_model_name}.py",
|
||||
f"{path_to_transformer_root}/src/transformers/modeling_tf_{lowercase_model_name}.py",
|
||||
f"{model_dir}/modeling_tf_{lowercase_model_name}.py",
|
||||
)
|
||||
|
||||
shutil.move(
|
||||
@@ -129,7 +136,7 @@ class AddNewModelCommand(BaseTransformersCLICommand):
|
||||
|
||||
shutil.move(
|
||||
f"{directory}/tokenization_{lowercase_model_name}.py",
|
||||
f"{path_to_transformer_root}/src/transformers/tokenization_{lowercase_model_name}.py",
|
||||
f"{model_dir}/tokenization_{lowercase_model_name}.py",
|
||||
)
|
||||
|
||||
from os import fdopen, remove
|
||||
|
||||
Reference in New Issue
Block a user