PyTorch >= 1.7.0 and TensorFlow >= 2.4.0 (#19016)
This commit is contained in:
4
setup.py
4
setup.py
@@ -155,13 +155,13 @@ _deps = [
|
|||||||
"librosa",
|
"librosa",
|
||||||
"starlette",
|
"starlette",
|
||||||
"tensorflow-cpu>=2.3",
|
"tensorflow-cpu>=2.3",
|
||||||
"tensorflow>=2.3",
|
"tensorflow>=2.4",
|
||||||
"tensorflow-text",
|
"tensorflow-text",
|
||||||
"tf2onnx",
|
"tf2onnx",
|
||||||
"timeout-decorator",
|
"timeout-decorator",
|
||||||
"timm",
|
"timm",
|
||||||
"tokenizers>=0.11.1,!=0.11.3,<0.13",
|
"tokenizers>=0.11.1,!=0.11.3,<0.13",
|
||||||
"torch>=1.0,!=0.12.0",
|
"torch>=1.7,!=1.12.0",
|
||||||
"torchaudio",
|
"torchaudio",
|
||||||
"pyctcdecode>=0.3.0",
|
"pyctcdecode>=0.3.0",
|
||||||
"tqdm>=4.27",
|
"tqdm>=4.27",
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class GELUActivation(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, use_gelu_python: bool = False):
|
def __init__(self, use_gelu_python: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.4") or use_gelu_python:
|
if use_gelu_python:
|
||||||
self.act = self._gelu_python
|
self.act = self._gelu_python
|
||||||
else:
|
else:
|
||||||
self.act = nn.functional.gelu
|
self.act = nn.functional.gelu
|
||||||
@@ -108,18 +108,8 @@ class SiLUActivation(nn.Module):
|
|||||||
later.
|
later.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
|
|
||||||
self.act = self._silu_python
|
|
||||||
else:
|
|
||||||
self.act = nn.functional.silu
|
|
||||||
|
|
||||||
def _silu_python(self, input: Tensor) -> Tensor:
|
|
||||||
return input * torch.sigmoid(input)
|
|
||||||
|
|
||||||
def forward(self, input: Tensor) -> Tensor:
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
return self.act(input)
|
return nn.functional.silu(input)
|
||||||
|
|
||||||
|
|
||||||
class MishActivation(nn.Module):
|
class MishActivation(nn.Module):
|
||||||
|
|||||||
@@ -61,13 +61,13 @@ deps = {
|
|||||||
"librosa": "librosa",
|
"librosa": "librosa",
|
||||||
"starlette": "starlette",
|
"starlette": "starlette",
|
||||||
"tensorflow-cpu": "tensorflow-cpu>=2.3",
|
"tensorflow-cpu": "tensorflow-cpu>=2.3",
|
||||||
"tensorflow": "tensorflow>=2.3",
|
"tensorflow": "tensorflow>=2.4",
|
||||||
"tensorflow-text": "tensorflow-text",
|
"tensorflow-text": "tensorflow-text",
|
||||||
"tf2onnx": "tf2onnx",
|
"tf2onnx": "tf2onnx",
|
||||||
"timeout-decorator": "timeout-decorator",
|
"timeout-decorator": "timeout-decorator",
|
||||||
"timm": "timm",
|
"timm": "timm",
|
||||||
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
|
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
|
||||||
"torch": "torch>=1.0,!=0.12.0",
|
"torch": "torch>=1.7,!=1.12.0",
|
||||||
"torchaudio": "torchaudio",
|
"torchaudio": "torchaudio",
|
||||||
"pyctcdecode": "pyctcdecode>=0.3.0",
|
"pyctcdecode": "pyctcdecode>=0.3.0",
|
||||||
"tqdm": "tqdm>=4.27",
|
"tqdm": "tqdm>=4.27",
|
||||||
|
|||||||
@@ -34,12 +34,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
@@ -216,11 +211,8 @@ class AlbertEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
||||||
|
|||||||
@@ -40,12 +40,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
@@ -199,11 +194,8 @@ class BertEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import apply_chunking_to_forward, is_torch_greater_than_1_6
|
from ...pytorch_utils import apply_chunking_to_forward
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
@@ -259,11 +259,8 @@ class BigBirdEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
# End copy
|
# End copy
|
||||||
|
|
||||||
|
|||||||
@@ -35,12 +35,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel, SequenceSummary
|
from ...modeling_utils import PreTrainedModel, SequenceSummary
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||||
from .configuration_convbert import ConvBertConfig
|
from .configuration_convbert import ConvBertConfig
|
||||||
|
|
||||||
@@ -198,11 +193,8 @@ class ConvBertEmbeddings(nn.Module):
|
|||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -34,12 +34,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -87,11 +82,8 @@ class Data2VecTextForTextEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# End copy
|
# End copy
|
||||||
|
|||||||
@@ -22,15 +22,12 @@ from typing import Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
|
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
||||||
Conv1D,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_or_equal_than_1_6,
|
|
||||||
prune_conv1d_layer,
|
|
||||||
)
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -38,15 +35,6 @@ from ...utils import (
|
|||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_torch_greater_or_equal_than_1_6:
|
|
||||||
is_amp_available = True
|
|
||||||
from torch.cuda.amp import autocast
|
|
||||||
else:
|
|
||||||
is_amp_available = False
|
|
||||||
|
|
||||||
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
|
||||||
from .configuration_decision_transformer import DecisionTransformerConfig
|
from .configuration_decision_transformer import DecisionTransformerConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -235,15 +223,10 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
|||||||
scale_factor /= float(self.layer_idx + 1)
|
scale_factor /= float(self.layer_idx + 1)
|
||||||
|
|
||||||
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
||||||
if is_amp_available:
|
|
||||||
with autocast(enabled=False):
|
with autocast(enabled=False):
|
||||||
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
||||||
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
||||||
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
||||||
else:
|
|
||||||
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
|
||||||
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
|
||||||
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
|
||||||
|
|
||||||
if not self.is_cross_attention:
|
if not self.is_cross_attention:
|
||||||
# if only "normal" attention layer implements causal mask
|
# if only "normal" attention layer implements causal mask
|
||||||
|
|||||||
@@ -39,12 +39,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -106,7 +101,6 @@ class Embeddings(nn.Module):
|
|||||||
|
|
||||||
self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
|
self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
|
||||||
self.dropout = nn.Dropout(config.dropout)
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,12 +36,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel, SequenceSummary
|
from ...modeling_utils import PreTrainedModel, SequenceSummary
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
@@ -169,11 +164,8 @@ class ElectraEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
||||||
|
|||||||
@@ -38,12 +38,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
@@ -96,11 +91,8 @@ class ErnieEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ...modeling_outputs import BaseModelOutput
|
from ...modeling_outputs import BaseModelOutput
|
||||||
from ...pytorch_utils import is_torch_greater_than_1_6
|
|
||||||
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||||
from ..xlm.modeling_xlm import (
|
from ..xlm.modeling_xlm import (
|
||||||
XLMForMultipleChoice,
|
XLMForMultipleChoice,
|
||||||
@@ -139,7 +138,6 @@ class FlaubertModel(XLMModel):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.layerdrop = getattr(config, "layerdrop", 0.0)
|
self.layerdrop = getattr(config, "layerdrop", 0.0)
|
||||||
self.pre_norm = getattr(config, "pre_norm", False)
|
self.pre_norm = getattr(config, "pre_norm", False)
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from transformers.utils.doc import add_code_sample_docstrings
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||||
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
from ...pytorch_utils import is_torch_greater_than_1_6
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -392,11 +391,8 @@ class FlavaTextEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import apply_chunking_to_forward, is_torch_greater_than_1_6
|
from ...pytorch_utils import apply_chunking_to_forward
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -117,11 +117,8 @@ class FNetEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
|
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
||||||
|
|||||||
@@ -23,22 +23,9 @@ from typing import Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...pytorch_utils import (
|
|
||||||
Conv1D,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_or_equal_than_1_6,
|
|
||||||
prune_conv1d_layer,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_greater_or_equal_than_1_6:
|
|
||||||
is_amp_available = True
|
|
||||||
from torch.cuda.amp import autocast
|
|
||||||
else:
|
|
||||||
is_amp_available = False
|
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
@@ -47,6 +34,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel, SequenceSummary
|
from ...modeling_utils import PreTrainedModel, SequenceSummary
|
||||||
|
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
@@ -247,15 +235,10 @@ class GPT2Attention(nn.Module):
|
|||||||
scale_factor /= float(self.layer_idx + 1)
|
scale_factor /= float(self.layer_idx + 1)
|
||||||
|
|
||||||
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
||||||
if is_amp_available:
|
|
||||||
with autocast(enabled=False):
|
with autocast(enabled=False):
|
||||||
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
||||||
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
||||||
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
||||||
else:
|
|
||||||
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
|
||||||
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
|
||||||
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
|
||||||
|
|
||||||
if not self.is_cross_attention:
|
if not self.is_cross_attention:
|
||||||
# if only "normal" attention layer implements causal mask
|
# if only "normal" attention layer implements causal mask
|
||||||
|
|||||||
@@ -22,22 +22,9 @@ from typing import Any, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...pytorch_utils import (
|
|
||||||
Conv1D,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_or_equal_than_1_6,
|
|
||||||
prune_conv1d_layer,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_greater_or_equal_than_1_6:
|
|
||||||
is_amp_available = True
|
|
||||||
from torch.cuda.amp import autocast
|
|
||||||
else:
|
|
||||||
is_amp_available = False
|
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
@@ -45,6 +32,7 @@ from ...modeling_outputs import (
|
|||||||
SequenceClassifierOutputWithPast,
|
SequenceClassifierOutputWithPast,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
|
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||||
from .configuration_imagegpt import ImageGPTConfig
|
from .configuration_imagegpt import ImageGPTConfig
|
||||||
|
|
||||||
@@ -299,15 +287,10 @@ class ImageGPTAttention(nn.Module):
|
|||||||
scale_factor /= float(self.layer_idx + 1)
|
scale_factor /= float(self.layer_idx + 1)
|
||||||
|
|
||||||
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
||||||
if is_amp_available:
|
|
||||||
with autocast(enabled=False):
|
with autocast(enabled=False):
|
||||||
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
||||||
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
||||||
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
||||||
else:
|
|
||||||
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
|
||||||
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
|
||||||
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
|
||||||
|
|
||||||
if not self.is_cross_attention:
|
if not self.is_cross_attention:
|
||||||
# if only "normal" attention layer implements causal mask
|
# if only "normal" attention layer implements causal mask
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ from ...modeling_utils import (
|
|||||||
find_pruneable_heads_and_indices,
|
find_pruneable_heads_and_indices,
|
||||||
prune_linear_layer,
|
prune_linear_layer,
|
||||||
)
|
)
|
||||||
from ...pytorch_utils import is_torch_greater_than_1_6
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_mctct import MCTCTConfig
|
from .configuration_mctct import MCTCTConfig
|
||||||
|
|
||||||
@@ -153,7 +152,6 @@ class MCTCTEmbeddings(nn.Module):
|
|||||||
|
|
||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids",
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
|
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
|
||||||
|
|||||||
@@ -38,12 +38,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
@@ -187,11 +182,8 @@ class NezhaEmbeddings(nn.Module):
|
|||||||
# any TensorFlow checkpoint file
|
# any TensorFlow checkpoint file
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros((1, config.max_position_embeddings), dtype=torch.long), persistent=False
|
||||||
torch.zeros((1, config.max_position_embeddings), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -33,12 +33,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||||
from .configuration_nystromformer import NystromformerConfig
|
from .configuration_nystromformer import NystromformerConfig
|
||||||
|
|
||||||
@@ -72,7 +67,6 @@ class NystromformerEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids",
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
|
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import find_pruneable_heads_and_indices, is_torch_greater_than_1_6, prune_linear_layer
|
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -166,11 +166,8 @@ class QDQBertEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -31,12 +31,7 @@ from ...modeling_outputs import (
|
|||||||
ModelOutput,
|
ModelOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||||
from .configuration_realm import RealmConfig
|
from .configuration_realm import RealmConfig
|
||||||
|
|
||||||
@@ -185,11 +180,8 @@ class RealmEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -35,12 +35,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -87,11 +82,8 @@ class RobertaEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# End copy
|
# End copy
|
||||||
|
|||||||
@@ -34,12 +34,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import find_pruneable_heads_and_indices, is_torch_greater_or_equal_than_1_10, prune_linear_layer
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_or_equal_than_1_10,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||||
from .configuration_vilt import ViltConfig
|
from .configuration_vilt import ViltConfig
|
||||||
|
|
||||||
@@ -255,11 +250,8 @@ class TextEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
||||||
|
|||||||
@@ -34,12 +34,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -80,11 +75,8 @@ class XLMRobertaXLEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long),
|
|
||||||
persistent=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# End copy
|
# End copy
|
||||||
|
|||||||
@@ -34,12 +34,7 @@ from ...modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import (
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
apply_chunking_to_forward,
|
|
||||||
find_pruneable_heads_and_indices,
|
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
prune_linear_layer,
|
|
||||||
)
|
|
||||||
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||||
from .configuration_yoso import YosoConfig
|
from .configuration_yoso import YosoConfig
|
||||||
|
|
||||||
@@ -261,7 +256,6 @@ class YosoEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids",
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
|
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
|
||||||
|
|||||||
@@ -26,8 +26,7 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
|
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
|
||||||
is_torch_greater_or_equal_than_1_6 = parsed_torch_version_base >= version.parse("1.6.0")
|
|
||||||
is_torch_greater_than_1_6 = parsed_torch_version_base > version.parse("1.6.0")
|
|
||||||
is_torch_less_than_1_8 = parsed_torch_version_base < version.parse("1.8.0")
|
is_torch_less_than_1_8 = parsed_torch_version_base < version.parse("1.8.0")
|
||||||
is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10")
|
is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10")
|
||||||
is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
|
is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
|
||||||
|
|||||||
@@ -71,12 +71,7 @@ from .modelcard import TrainingSummary
|
|||||||
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
|
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
|
||||||
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
|
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
|
||||||
from .optimization import Adafactor, get_scheduler
|
from .optimization import Adafactor, get_scheduler
|
||||||
from .pytorch_utils import (
|
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
|
||||||
ALL_LAYERNORM_LAYERS,
|
|
||||||
is_torch_greater_or_equal_than_1_6,
|
|
||||||
is_torch_greater_or_equal_than_1_10,
|
|
||||||
is_torch_less_than_1_11,
|
|
||||||
)
|
|
||||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
from .trainer_callback import (
|
from .trainer_callback import (
|
||||||
CallbackHandler,
|
CallbackHandler,
|
||||||
@@ -155,9 +150,7 @@ from .utils import (
|
|||||||
from .utils.generic import ContextManagers
|
from .utils.generic import ContextManagers
|
||||||
|
|
||||||
|
|
||||||
_is_torch_generator_available = False
|
_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10
|
||||||
_is_native_cuda_amp_available = False
|
|
||||||
_is_native_cpu_amp_available = False
|
|
||||||
|
|
||||||
DEFAULT_CALLBACKS = [DefaultFlowCallback]
|
DEFAULT_CALLBACKS = [DefaultFlowCallback]
|
||||||
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
|
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
|
||||||
@@ -170,13 +163,6 @@ if is_in_notebook():
|
|||||||
if is_apex_available():
|
if is_apex_available():
|
||||||
from apex import amp
|
from apex import amp
|
||||||
|
|
||||||
if is_torch_greater_or_equal_than_1_6:
|
|
||||||
_is_torch_generator_available = True
|
|
||||||
_is_native_cuda_amp_available = True
|
|
||||||
|
|
||||||
if is_torch_greater_or_equal_than_1_10:
|
|
||||||
_is_native_cpu_amp_available = True
|
|
||||||
|
|
||||||
if is_datasets_available():
|
if is_datasets_available():
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
@@ -565,12 +551,7 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Tried to use cpu amp but native cpu amp is not available")
|
raise ValueError("Tried to use cpu amp but native cpu amp is not available")
|
||||||
else:
|
else:
|
||||||
if _is_native_cuda_amp_available:
|
|
||||||
args.half_precision_backend = "cuda_amp"
|
args.half_precision_backend = "cuda_amp"
|
||||||
elif args.bf16:
|
|
||||||
raise ValueError("Tried to use `bf16` but native amp is not available")
|
|
||||||
else:
|
|
||||||
args.half_precision_backend = "apex"
|
|
||||||
|
|
||||||
logger.info(f"Using {args.half_precision_backend} half precision backend")
|
logger.info(f"Using {args.half_precision_backend} half precision backend")
|
||||||
|
|
||||||
@@ -781,7 +762,7 @@ class Trainer:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
generator = None
|
generator = None
|
||||||
if self.args.world_size <= 1 and _is_torch_generator_available:
|
if self.args.world_size <= 1:
|
||||||
generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
# for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
|
# for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
|
||||||
# `args.seed`) if data_seed isn't provided.
|
# `args.seed`) if data_seed isn't provided.
|
||||||
@@ -826,9 +807,7 @@ class Trainer:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
if self.args.world_size <= 1:
|
if self.args.world_size <= 1:
|
||||||
if _is_torch_generator_available:
|
|
||||||
return RandomSampler(self.train_dataset, generator=generator)
|
return RandomSampler(self.train_dataset, generator=generator)
|
||||||
return RandomSampler(self.train_dataset)
|
|
||||||
elif (
|
elif (
|
||||||
self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
|
self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
|
||||||
and not self.args.dataloader_drop_last
|
and not self.args.dataloader_drop_last
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ from typing import Any, Dict, Iterator, List, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from packaging import version
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
|
from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
@@ -831,12 +830,7 @@ def _get_learning_rate(self):
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
last_lr = (
|
last_lr = self.lr_scheduler.get_last_lr()[0]
|
||||||
# backward compatibility for pytorch schedulers
|
|
||||||
self.lr_scheduler.get_last_lr()[0]
|
|
||||||
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.4")
|
|
||||||
else self.lr_scheduler.get_lr()[0]
|
|
||||||
)
|
|
||||||
if torch.is_tensor(last_lr):
|
if torch.is_tensor(last_lr):
|
||||||
last_lr = last_lr.item()
|
last_lr = last_lr.item()
|
||||||
return last_lr
|
return last_lr
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ from ...pytorch_utils import (
|
|||||||
apply_chunking_to_forward,
|
apply_chunking_to_forward,
|
||||||
find_pruneable_heads_and_indices,
|
find_pruneable_heads_and_indices,
|
||||||
prune_linear_layer,
|
prune_linear_layer,
|
||||||
is_torch_greater_than_1_6,
|
|
||||||
)
|
)
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
|
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
|
||||||
@@ -157,7 +156,6 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
if is_torch_greater_than_1_6:
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"token_type_ids",
|
"token_type_ids",
|
||||||
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
|
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
|
||||||
|
|||||||
Reference in New Issue
Block a user