Reorganize repo (#8580)
* Put models in subfolders * Styling * Fix imports in tests * More fixes in test imports * Sneaky hidden imports * Fix imports in doc files * More sneaky imports * Finish fixing tests * Fix examples * Fix path for copies * More fixes for examples * Fix dummy files * More fixes for example * More model import fixes * Is this why you're unhappy GitHub? * Fix imports in conver command
This commit is contained in:
@@ -73,7 +73,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
def run(self):
|
||||
if self._model_type == "albert":
|
||||
try:
|
||||
from transformers.convert_albert_original_tf_checkpoint_to_pytorch import (
|
||||
from transformers.models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -82,7 +82,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "bert":
|
||||
try:
|
||||
from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
|
||||
from transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -91,7 +91,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "funnel":
|
||||
try:
|
||||
from transformers.convert_funnel_original_tf_checkpoint_to_pytorch import (
|
||||
from transformers.models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -99,14 +99,14 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "gpt":
|
||||
from transformers.convert_openai_original_tf_checkpoint_to_pytorch import (
|
||||
from transformers.models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
|
||||
convert_openai_checkpoint_to_pytorch,
|
||||
)
|
||||
|
||||
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "transfo_xl":
|
||||
try:
|
||||
from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
|
||||
from transformers.models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
|
||||
convert_transfo_xl_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -123,7 +123,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
)
|
||||
elif self._model_type == "gpt2":
|
||||
try:
|
||||
from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import (
|
||||
from transformers.models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
|
||||
convert_gpt2_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -132,7 +132,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "xlnet":
|
||||
try:
|
||||
from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import (
|
||||
from transformers.models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
|
||||
convert_xlnet_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -142,13 +142,13 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
|
||||
)
|
||||
elif self._model_type == "xlm":
|
||||
from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
|
||||
from transformers.models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
|
||||
convert_xlm_checkpoint_to_pytorch,
|
||||
)
|
||||
|
||||
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
|
||||
elif self._model_type == "lxmert":
|
||||
from transformers.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
|
||||
from transformers.models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
|
||||
convert_lxmert_checkpoint_to_pytorch,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user