Support passing flash_attn_kwargs when gradient_checkpointing is enabled (#37037)
* support passing flash_attn_kwargs when gradient_checkpointing is enabled * make modeling_deepspeek_v3.py consistent with modular_deepseek_v3.py
This commit is contained in:
@@ -4,6 +4,7 @@
|
|||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_dummy.py file directly. One of our CI enforces this.
|
# modular_dummy.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -544,7 +545,7 @@ class DummyModel(DummyPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_multimodal1.py file directly. One of our CI enforces this.
|
# modular_multimodal1.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -544,7 +545,7 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
@@ -963,7 +964,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -27,6 +27,7 @@
|
|||||||
# This file is based on the LLama model definition file in transformers
|
# This file is based on the LLama model definition file in transformers
|
||||||
|
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -613,7 +614,7 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -634,7 +635,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
position_embeddings,
|
position_embeddings,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -533,7 +534,7 @@ class Cohere2Model(Gemma2Model):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
position_embeddings,
|
position_embeddings,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
# modular_deepseek_v3.py file directly. One of our CI enforces this.
|
# modular_deepseek_v3.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
import math
|
import math
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -759,7 +760,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import math
|
import math
|
||||||
|
from functools import partial
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -852,7 +853,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -21,7 +21,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from functools import cached_property
|
from functools import cached_property, partial
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -1439,7 +1439,7 @@ class Emu3TextModel(Emu3PreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -645,7 +646,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
position_embeddings,
|
position_embeddings,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -491,7 +492,7 @@ class Gemma2Model(GemmaModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
position_embeddings,
|
position_embeddings,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@
|
|||||||
import copy
|
import copy
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -732,7 +733,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
position_embeddings_global,
|
position_embeddings_global,
|
||||||
position_embeddings_local,
|
position_embeddings_local,
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
import copy
|
import copy
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -662,7 +663,7 @@ class Gemma3TextModel(Gemma2Model):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
position_embeddings_global,
|
position_embeddings_global,
|
||||||
position_embeddings_local,
|
position_embeddings_local,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -594,7 +595,7 @@ class GlmModel(GlmPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -593,7 +594,7 @@ class GraniteModel(GranitePreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from functools import partial
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -185,7 +186,7 @@ class GraniteModel(LlamaModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import math
|
import math
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -581,7 +582,7 @@ class HeliumModel(HeliumPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -583,7 +584,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_mistral.py file directly. One of our CI enforces this.
|
# modular_mistral.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -548,7 +549,7 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -24,6 +24,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -672,7 +673,7 @@ class MixtralModel(MixtralPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch Mixtral model."""
|
"""PyTorch Mixtral model."""
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -400,7 +401,7 @@ class MixtralModel(MistralModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -936,7 +937,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -832,7 +833,7 @@ class MoonshineDecoder(LlamaModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_olmo.py file directly. One of our CI enforces this.
|
# modular_olmo.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -559,7 +560,7 @@ class OlmoModel(OlmoPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_olmo2.py file directly. One of our CI enforces this.
|
# modular_olmo2.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -560,7 +561,7 @@ class Olmo2Model(Olmo2PreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_phi.py file directly. One of our CI enforces this.
|
# modular_phi.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -553,7 +554,7 @@ class PhiModel(PhiPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -243,7 +244,7 @@ class PhiModel(LlamaModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -623,7 +624,7 @@ class Phi3Model(Phi3PreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_qwen2.py file directly. One of our CI enforces this.
|
# modular_qwen2.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -561,7 +562,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user