fix typos in the code comments and error messages (#36993)
* chore: enhance code comments
* chore: enhance code comments
* chore: enhance code comments
* chore: enhance code comments
* chore: enhance code comments
* chore: enhance code comments
* chore: enhance code comments
This commit is contained in:
@@ -798,7 +798,7 @@ class AltCLIPAttention(nn.Module):
|
|||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit awkward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
# In order to do so, attn_weights have to reshaped
|
# In order to do so, attn_weights have to reshaped
|
||||||
# twice and have to be reused in the following
|
# twice and have to be reused in the following
|
||||||
|
|||||||
@@ -267,7 +267,7 @@ def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
|
|||||||
output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)
|
output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)
|
||||||
|
|
||||||
cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
|
cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
|
||||||
# Insert zero at the begining for offset index's convenience
|
# Insert zero at the beginning for offset index's convenience
|
||||||
zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
|
zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
|
||||||
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
|
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
|
||||||
|
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
|
|||||||
output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)
|
output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)
|
||||||
|
|
||||||
cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
|
cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
|
||||||
# Insert zero at the begining for offset index's convenience
|
# Insert zero at the beginning for offset index's convenience
|
||||||
zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
|
zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
|
||||||
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
|
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
|
||||||
|
|
||||||
|
|||||||
@@ -373,7 +373,7 @@ class CLIPAttention(nn.Module):
|
|||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit awkward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
# In order to do so, attn_weights have to reshaped
|
# In order to do so, attn_weights have to reshaped
|
||||||
# twice and have to be reused in the following
|
# twice and have to be reused in the following
|
||||||
|
|||||||
@@ -341,7 +341,7 @@ class CLIPSegAttention(nn.Module):
|
|||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit awkward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
# In order to do so, attn_weights have to reshaped
|
# In order to do so, attn_weights have to reshaped
|
||||||
# twice and have to be reused in the following
|
# twice and have to be reused in the following
|
||||||
|
|||||||
@@ -1016,7 +1016,7 @@ class EsmFoldSelfAttention(nn.Module):
|
|||||||
use mask.
|
use mask.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
x: batch of input sequneces (.. x L x C) mask: batch of boolean masks where 1=valid, 0=padding position (..
|
x: batch of input sequences (.. x L x C) mask: batch of boolean masks where 1=valid, 0=padding position (..
|
||||||
x L_k) bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads)
|
x L_k) bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads)
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
|
|||||||
@@ -989,10 +989,10 @@ class Rigid:
|
|||||||
|
|
||||||
def to_tensor_4x4(self) -> torch.Tensor:
|
def to_tensor_4x4(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Converts a transformation to a homogenous transformation tensor.
|
Converts a transformation to a homogeneous transformation tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A [*, 4, 4] homogenous transformation tensor
|
A [*, 4, 4] homogeneous transformation tensor
|
||||||
"""
|
"""
|
||||||
tensor = self._trans.new_zeros((*self.shape, 4, 4))
|
tensor = self._trans.new_zeros((*self.shape, 4, 4))
|
||||||
tensor[..., :3, :3] = self._rots.get_rot_mats()
|
tensor[..., :3, :3] = self._rots.get_rot_mats()
|
||||||
@@ -1003,10 +1003,10 @@ class Rigid:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_tensor_4x4(t: torch.Tensor) -> Rigid:
|
def from_tensor_4x4(t: torch.Tensor) -> Rigid:
|
||||||
"""
|
"""
|
||||||
Constructs a transformation from a homogenous transformation tensor.
|
Constructs a transformation from a homogeneous transformation tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
t: [*, 4, 4] homogenous transformation tensor
|
t: [*, 4, 4] homogeneous transformation tensor
|
||||||
Returns:
|
Returns:
|
||||||
T object with shape [*]
|
T object with shape [*]
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class FalconMambaConfig(PretrainedConfig):
|
|||||||
use_cache (`bool`, *optional*, defaults to `True`):
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not the cache should be used.
|
Whether or not the cache should be used.
|
||||||
use_mambapy (`bool`, *optional*, defaults to `False`):
|
use_mambapy (`bool`, *optional*, defaults to `False`):
|
||||||
Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not avaiable. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
|
Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not available. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
|
||||||
mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
|
mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
|
The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
|
||||||
Example:
|
Example:
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class FalconMambaMixer(nn.Module):
|
|||||||
|
|
||||||
# projection of the input hidden states
|
# projection of the input hidden states
|
||||||
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
|
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
|
||||||
# selective projection used to make dt, B and C input dependant
|
# selective projection used to make dt, B and C input dependent
|
||||||
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
|
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
|
||||||
# time step projection (discretization)
|
# time step projection (discretization)
|
||||||
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
|
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
|
||||||
@@ -768,7 +768,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
|
|||||||
attention_mask: Optional[torch.LongTensor] = None,
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# Overwitten -- uses `cache_params` as opposed to `past_key_values`
|
# Overwritten -- uses `cache_params` as opposed to `past_key_values`
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
# `cache_position` should have been initialized in `generate`
|
# `cache_position` should have been initialized in `generate`
|
||||||
|
|||||||
@@ -791,7 +791,7 @@ class GitVisionAttention(nn.Module):
|
|||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit awkward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
# In order to do so, attn_weights have to reshaped
|
# In order to do so, attn_weights have to reshaped
|
||||||
# twice and have to be reused in the following
|
# twice and have to be reused in the following
|
||||||
|
|||||||
@@ -237,7 +237,7 @@ class IdeficsVisionAttention(nn.Module):
|
|||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit awkward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
# In order to do so, attn_weights have to reshaped
|
# In order to do so, attn_weights have to reshaped
|
||||||
# twice and have to be reused in the following
|
# twice and have to be reused in the following
|
||||||
|
|||||||
@@ -543,7 +543,7 @@ class Kosmos2VisionAttention(nn.Module):
|
|||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit awkward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
# In order to do so, attn_weights have to reshaped
|
# In order to do so, attn_weights have to reshaped
|
||||||
# twice and have to be reused in the following
|
# twice and have to be reused in the following
|
||||||
|
|||||||
@@ -235,7 +235,7 @@ class LongT5LayerNorm(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
# LongT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
# LongT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||||
# half-precision inputs is done in fp32
|
# half-precision inputs is done in fp32
|
||||||
|
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ class MambaMixer(nn.Module):
|
|||||||
|
|
||||||
# projection of the input hidden states
|
# projection of the input hidden states
|
||||||
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
|
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
|
||||||
# selective projection used to make dt, B and C input dependant
|
# selective projection used to make dt, B and C input dependent
|
||||||
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
|
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
|
||||||
# time step projection (discretization)
|
# time step projection (discretization)
|
||||||
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
|
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
|
||||||
@@ -708,7 +708,7 @@ class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin):
|
|||||||
attention_mask: Optional[torch.LongTensor] = None,
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# Overwitten -- uses `cache_params` as opposed to `past_key_values`
|
# Overwritten -- uses `cache_params` as opposed to `past_key_values`
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
# `cache_position` should have been initialized in `generate`
|
# `cache_position` should have been initialized in `generate`
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ class MT5LayerNorm(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
# MT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
# MT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||||
# half-precision inputs is done in fp32
|
# half-precision inputs is done in fp32
|
||||||
|
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ class Pix2StructLayerNorm(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||||
# half-precision inputs is done in fp32
|
# half-precision inputs is done in fp32
|
||||||
|
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ class Pop2PianoLayerNorm(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
# Pop2Piano uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
# Pop2Piano uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||||
# half-precision inputs is done in fp32
|
# half-precision inputs is done in fp32
|
||||||
|
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ GIN_TO_CONFIG_MAPPING = {
|
|||||||
|
|
||||||
|
|
||||||
def convert_gin_to_config(gin_file, num_experts):
|
def convert_gin_to_config(gin_file, num_experts):
|
||||||
# Convert a google style config to the hugging face fromat
|
# Convert a google style config to the hugging face format
|
||||||
import regex as re
|
import regex as re
|
||||||
|
|
||||||
with open(gin_file, "r") as f:
|
with open(gin_file, "r") as f:
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ class SwitchTransformersLayerNorm(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
# SwitchTransformers uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
# SwitchTransformers uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||||
# half-precision inputs is done in fp32
|
# half-precision inputs is done in fp32
|
||||||
|
|
||||||
@@ -297,12 +297,12 @@ class SwitchTransformersSparseMLP(nn.Module):
|
|||||||
expert the corresponding hidden states.
|
expert the corresponding hidden states.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Step 1: Get the router_mask from the router as wel as the probabilities
|
# Step 1: Get the router_mask from the router as well as the probabilities
|
||||||
router_mask, router_probs, router_logits = self.router(hidden_states)
|
router_mask, router_probs, router_logits = self.router(hidden_states)
|
||||||
expert_index = torch.argmax(router_mask, dim=-1)
|
expert_index = torch.argmax(router_mask, dim=-1)
|
||||||
|
|
||||||
# The routers introduced might not always map all the tokens, to a router, which means that some hidden states
|
# The routers introduced might not always map all the tokens, to a router, which means that some hidden states
|
||||||
# can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the seleced ones.
|
# can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.
|
||||||
|
|
||||||
next_states = hidden_states.clone()
|
next_states = hidden_states.clone()
|
||||||
|
|
||||||
|
|||||||
@@ -218,7 +218,7 @@ def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_f
|
|||||||
flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
|
flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
|
||||||
|
|
||||||
flax_model.save_pretrained(flax_dump_folder_path)
|
flax_model.save_pretrained(flax_dump_folder_path)
|
||||||
print("T5X Model was sucessfully converted!")
|
print("T5X Model was successfully converted!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -249,7 +249,7 @@ class T5LayerNorm(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||||
# half-precision inputs is done in fp32
|
# half-precision inputs is done in fp32
|
||||||
|
|
||||||
|
|||||||
@@ -524,7 +524,7 @@ class UdopLayerNorm(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
# Udop uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
# Udop uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||||
# half-precision inputs is done in fp32
|
# half-precision inputs is done in fp32
|
||||||
|
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class UMT5LayerNorm(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
# UMT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
# UMT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||||
# half-precision inputs is done in fp32
|
# half-precision inputs is done in fp32
|
||||||
|
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ def transform_attention(current: np.ndarray):
|
|||||||
return transform_attention_kernel(current)
|
return transform_attention_kernel(current)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Invalid number of dimesions: {np.ndim(current)}")
|
raise Exception(f"Invalid number of dimensions: {np.ndim(current)}")
|
||||||
|
|
||||||
|
|
||||||
def transform_attention_bias(current: np.ndarray):
|
def transform_attention_bias(current: np.ndarray):
|
||||||
|
|||||||
@@ -300,7 +300,7 @@ class XCLIPAttention(nn.Module):
|
|||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit awkward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
# In order to do so, attn_weights have to reshaped
|
# In order to do so, attn_weights have to reshaped
|
||||||
# twice and have to be reused in the following
|
# twice and have to be reused in the following
|
||||||
|
|||||||
Reference in New Issue
Block a user