Fix Torch.hub + Integration test
This commit is contained in:
32
.github/workflows/github-torch-hub.yml
vendored
Normal file
32
.github/workflows/github-torch-hub.yml
vendored
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
name: Torch hub integration
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- "*"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
torch_hub_integration:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
# no checkout necessary here.
|
||||||
|
- name: Extract branch name
|
||||||
|
run: echo "::set-env name=BRANCH::${GITHUB_REF#refs/heads/}"
|
||||||
|
- name: Check branch name
|
||||||
|
run: echo $BRANCH
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v1
|
||||||
|
with:
|
||||||
|
python-version: 3.7
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install torch
|
||||||
|
pip install numpy tokenizers boto3 filelock requests tqdm regex sentencepiece sacremoses
|
||||||
|
|
||||||
|
- name: Torch hub list
|
||||||
|
run: |
|
||||||
|
python -c "import torch; print(torch.hub.list('huggingface/transformers:$BRANCH'))"
|
||||||
|
|
||||||
|
- name: Torch hub help
|
||||||
|
run: |
|
||||||
|
python -c "import torch; print(torch.hub.help('huggingface/transformers:$BRANCH', 'modelForSequenceClassification'))"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from transformers import (
|
from src.transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
@@ -6,10 +6,10 @@ from transformers import (
|
|||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
)
|
)
|
||||||
from transformers.file_utils import add_start_docstrings
|
from src.transformers.file_utils import add_start_docstrings
|
||||||
|
|
||||||
|
|
||||||
dependencies = ["torch", "tqdm", "boto3", "requests", "regex", "sentencepiece", "sacremoses"]
|
dependencies = ["torch", "numpy", "tokenizers", "boto3", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(AutoConfig.__doc__)
|
@add_start_docstrings(AutoConfig.__doc__)
|
||||||
|
|||||||
@@ -22,11 +22,10 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from transformers.configuration_albert import AlbertConfig
|
from .configuration_albert import AlbertConfig
|
||||||
from transformers.modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
|
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
|
||||||
|
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
|
from .modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
|
||||||
|
from .modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -4,10 +4,9 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from transformers import ElectraConfig, add_start_docstrings
|
from .activations import get_activation
|
||||||
from transformers.activations import get_activation
|
from .configuration_electra import ElectraConfig
|
||||||
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .file_utils import add_start_docstrings_to_callable
|
|
||||||
from .modeling_bert import BertEmbeddings, BertEncoder, BertLayerNorm, BertPreTrainedModel
|
from .modeling_bert import BertEmbeddings, BertEncoder, BertLayerNorm, BertPreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -22,8 +22,7 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
from .tokenization_xlnet import SPIECE_UNDERLINE
|
from .tokenization_xlnet import SPIECE_UNDERLINE
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,8 +20,7 @@ import os
|
|||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
from .tokenization_xlnet import SPIECE_UNDERLINE
|
from .tokenization_xlnet import SPIECE_UNDERLINE
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user