diff --git a/setup.py b/setup.py index 19dad3b332..706293ee31 100644 --- a/setup.py +++ b/setup.py @@ -127,7 +127,9 @@ setup( "sacremoses", ], extras_require=extras, - scripts=["transformers-cli"], + entry_points={ + "console_scripts": ["transformers-cli=transformers.commands.transformers_cli:main"] + }, python_requires=">=3.6.0", classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/transformers-cli b/src/transformers/commands/transformers_cli.py old mode 100755 new mode 100644 similarity index 76% rename from transformers-cli rename to src/transformers/commands/transformers_cli.py index 9813b83843..ecc2ce96d9 --- a/transformers-cli +++ b/src/transformers/commands/transformers_cli.py @@ -8,9 +8,10 @@ from transformers.commands.run import RunCommand from transformers.commands.serving import ServeCommand from transformers.commands.user import UserCommands -if __name__ == '__main__': - parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli []') - commands_parser = parser.add_subparsers(help='transformers-cli command helpers') + +def main(): + parser = ArgumentParser("Transformers CLI tool", usage="transformers-cli []") + commands_parser = parser.add_subparsers(help="transformers-cli command helpers") # Register commands ConvertCommand.register_subcommand(commands_parser) @@ -23,10 +24,14 @@ if __name__ == '__main__': # Let's go args = parser.parse_args() - if not hasattr(args, 'func'): + if not hasattr(args, "func"): parser.print_help() exit(1) # Run service = args.func(args) service.run() + + +if __name__ == "__main__": + main()