remove blank line (+1 squashed commit) Squashed commits: [24ccd2061] [run-slow]vit_msn,vision_encoder_decoder (+24 squashed commits) Squashed commits: [08bd27e7a] [run-slow]vit_msn,vision_encoder_decoder [ec96a8db3] [run-slow]vit_msn [ead817eca] fix vit msn multi gpu [d12cdc8fd] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos [3fdbfa88f] doc [a3ff33e4a] finish implementation [e20b7b7fb] Update test_modeling_common.py [e290c5810] Update test_modeling_flax_common.py [d3af86f46] comment [ff7dd32d8] more comments [59b137889] suggestion [7e2ba6d67] attn_implementation as attribute of the class [fe66ab71f] minor [38642b568] Apply suggestions from code review Accept comments Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [22cde7d52] Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [48e137cc6] Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [99f4c679f] Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [96cf20a6d] Update src/transformers/models/vit_msn/modeling_vit_msn.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [c59377d23] Update src/transformers/models/vit_mae/modeling_vit_mae.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [b70a47259] Update tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> [00c84d216] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos [61f00ebb0] all tests are passing locally [e9e0b82b7] vision encoder/decoder [4d5076b56] test-vision (+20 squashed commits) Squashed commits: [d1add8db9] yolo 
[9fde65716] fix flax [986566c28] minor [ca2f21d1f] vit [3333efd7a] easy models change [ebfc21402] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos [b8b8603ed] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos [48ecc7e26] all tests are passing locally [bff7fc366] minor [62f88306f] fix yolo and text_encoder tests [121507555] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae [1064cae0a] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos [b7f52ff3a] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae [cffaa10dd] fix-copies [ef6c511c4] test vit hybrid [7d4ba8644] vit hybrid [66f919033] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae [1fcc0a031] fixes [cfde6eb21] fixup [e77df1ed3] all except yolo end encoder decoder (+17 squashed commits) Squashed commits: [602913e22] vit + vit_mae are working [547f6c4cc] RUN_SLOW=1 pytest tests/models/audio_spectrogram_transformer/ tests/models/deit/ tests/models/videomae/ passes [61a97dfa9] it s the complete opposite... [aefab37d4] fix more tests [71802a1b9] fix all torch tests [40b12eb58] encoder - decoder tests [941552b69] slow decorator where appropriate [14d055d80] has_attentions to yolo and msn [3381fa19f] add correct name [e261316a7] repo consistency [31c6d0c08] fixup [9d214276c] minor fix [11ed2e1b7] chore [eca6644c4] add sdpa to vit-based models [cffbf390b] make fix-copies result [6468319b0] fix style [d324cd02a] add sdpa for vit Co-authored-by: Liubov Yaronskaya <luba.yaronskaya@gmail.com>
This commit is contained in:
@@ -63,6 +63,7 @@ class ASTModelTester:
|
||||
scope=None,
|
||||
frequency_stride=2,
|
||||
time_stride=2,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -83,6 +84,7 @@ class ASTModelTester:
|
||||
self.scope = scope
|
||||
self.frequency_stride = frequency_stride
|
||||
self.time_stride = time_stride
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
# in AST, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
|
||||
frequency_out_dimension = (self.num_mel_bins - self.patch_size) // self.frequency_stride + 1
|
||||
@@ -117,6 +119,7 @@ class ASTModelTester:
|
||||
initializer_range=self.initializer_range,
|
||||
frequency_stride=self.frequency_stride,
|
||||
time_stride=self.time_stride,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_values, labels):
|
||||
|
||||
@@ -80,6 +80,8 @@ class DeiTModelTester:
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
encoder_stride=2,
|
||||
mask_ratio=0.5,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -99,10 +101,14 @@ class DeiTModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.encoder_stride = encoder_stride
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
# in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
self.seq_length = num_patches + 2
|
||||
self.mask_ratio = mask_ratio
|
||||
self.num_masks = int(mask_ratio * self.seq_length)
|
||||
self.mask_length = num_patches
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
@@ -130,6 +136,7 @@ class DeiTModelTester:
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
encoder_stride=self.encoder_stride,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
|
||||
@@ -71,6 +71,7 @@ class TFDeiTModelTester:
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
encoder_stride=2,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -90,6 +91,7 @@ class TFDeiTModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.encoder_stride = encoder_stride
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
# in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
@@ -121,6 +123,7 @@ class TFDeiTModelTester:
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
encoder_stride=self.encoder_stride,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
|
||||
@@ -70,6 +70,7 @@ class VideoMAEModelTester:
|
||||
initializer_range=0.02,
|
||||
mask_ratio=0.9,
|
||||
scope=None,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -91,6 +92,7 @@ class VideoMAEModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.mask_ratio = mask_ratio
|
||||
self.scope = scope
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
# in VideoMAE, the number of tokens equals num_frames/tubelet_size * num_patches per frame
|
||||
self.num_patches_per_frame = (image_size // patch_size) ** 2
|
||||
@@ -132,6 +134,7 @@ class VideoMAEModelTester:
|
||||
decoder_intermediate_size=self.intermediate_size,
|
||||
decoder_num_attention_heads=self.num_attention_heads,
|
||||
decoder_num_hidden_layers=self.num_hidden_layers,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
@@ -197,7 +200,8 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
# hence we define a single mask, which we then repeat for each example in the batch
|
||||
mask = torch.ones((self.model_tester.num_masks,))
|
||||
mask = torch.cat([mask, torch.zeros(self.model_tester.seq_length - mask.size(0))])
|
||||
bool_masked_pos = mask.expand(self.model_tester.batch_size, -1).bool()
|
||||
batch_size = inputs_dict["pixel_values"].shape[0]
|
||||
bool_masked_pos = mask.expand(batch_size, -1).bool()
|
||||
inputs_dict["bool_masked_pos"] = bool_masked_pos.to(torch_device)
|
||||
|
||||
if return_labels:
|
||||
|
||||
@@ -492,7 +492,9 @@ class TFVisionEncoderDecoderMixin:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tf_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
|
||||
pt_model = VisionEncoderDecoderModel.from_pretrained(
|
||||
tmpdirname, from_tf=True, attn_implementation=tf_model.config._attn_implementation
|
||||
)
|
||||
|
||||
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ class FlaxViTModelTester(unittest.TestCase):
|
||||
attention_probs_dropout_prob=0.1,
|
||||
type_sequence_label_size=10,
|
||||
initializer_range=0.02,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -66,6 +67,7 @@ class FlaxViTModelTester(unittest.TestCase):
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
@@ -87,6 +89,7 @@ class FlaxViTModelTester(unittest.TestCase):
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
@@ -63,6 +63,7 @@ class TFViTModelTester:
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -81,6 +82,7 @@ class TFViTModelTester:
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
@@ -111,6 +113,7 @@ class TFViTModelTester:
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
|
||||
@@ -68,6 +68,8 @@ class ViTModelTester:
|
||||
initializer_range=0.02,
|
||||
scope=None,
|
||||
encoder_stride=2,
|
||||
mask_ratio=0.5,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -87,10 +89,14 @@ class ViTModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.encoder_stride = encoder_stride
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
self.seq_length = num_patches + 1
|
||||
self.mask_ratio = mask_ratio
|
||||
self.num_masks = int(mask_ratio * self.seq_length)
|
||||
self.mask_length = num_patches
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
@@ -118,6 +124,7 @@ class ViTModelTester:
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
encoder_stride=self.encoder_stride,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
|
||||
@@ -58,6 +58,7 @@ class ViTHybridModelTester:
|
||||
initializer_range=0.02,
|
||||
backbone_featmap_shape=[1, 16, 4, 4],
|
||||
scope=None,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -77,6 +78,7 @@ class ViTHybridModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.backbone_featmap_shape = backbone_featmap_shape
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
# in ViT hybrid, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
# the number of patches is based on the feature map of the backbone, which by default uses an output stride
|
||||
@@ -122,6 +124,7 @@ class ViTHybridModelTester:
|
||||
backbone_featmap_shape=self.backbone_featmap_shape,
|
||||
backbone_config=backbone_config,
|
||||
backbone=None,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
|
||||
@@ -72,6 +72,7 @@ class TFViTMAEModelTester:
|
||||
num_labels=3,
|
||||
mask_ratio=0.6,
|
||||
scope=None,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -91,6 +92,7 @@ class TFViTMAEModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.mask_ratio = mask_ratio
|
||||
self.scope = scope
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
# in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded up
|
||||
# (we add 1 for the [CLS] token)
|
||||
@@ -127,6 +129,7 @@ class TFViTMAEModelTester:
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
mask_ratio=self.mask_ratio,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
|
||||
@@ -63,8 +63,9 @@ class ViTMAEModelTester:
|
||||
type_sequence_label_size=10,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
mask_ratio=0.6,
|
||||
scope=None,
|
||||
mask_ratio=0.5,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -84,11 +85,15 @@ class ViTMAEModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.mask_ratio = mask_ratio
|
||||
self.scope = scope
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
# in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded up
|
||||
# (we add 1 for the [CLS] token)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1)))
|
||||
self.mask_ratio = mask_ratio
|
||||
self.num_masks = int(mask_ratio * self.seq_length)
|
||||
self.mask_length = num_patches
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
@@ -120,6 +125,7 @@ class ViTMAEModelTester:
|
||||
decoder_intermediate_size=self.intermediate_size,
|
||||
decoder_num_attention_heads=self.num_attention_heads,
|
||||
decoder_num_hidden_layers=self.num_hidden_layers,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
|
||||
@@ -59,6 +59,7 @@ class ViTMSNModelTester:
|
||||
type_sequence_label_size=10,
|
||||
initializer_range=0.02,
|
||||
scope=None,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -77,6 +78,7 @@ class ViTMSNModelTester:
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
# in ViT MSN, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
@@ -106,6 +108,7 @@ class ViTMSNModelTester:
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
initializer_range=self.initializer_range,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
|
||||
@@ -62,6 +62,7 @@ class YolosModelTester:
|
||||
scope=None,
|
||||
n_targets=8,
|
||||
num_detection_tokens=10,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -83,6 +84,7 @@ class YolosModelTester:
|
||||
self.scope = scope
|
||||
self.n_targets = n_targets
|
||||
self.num_detection_tokens = num_detection_tokens
|
||||
self.attn_implementation = attn_implementation
|
||||
# we set the expected sequence length (which is used in several tests)
|
||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens
|
||||
num_patches = (image_size[1] // patch_size) * (image_size[0] // patch_size)
|
||||
@@ -123,6 +125,7 @@ class YolosModelTester:
|
||||
initializer_range=self.initializer_range,
|
||||
num_detection_tokens=self.num_detection_tokens,
|
||||
num_labels=self.num_labels,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
|
||||
@@ -2788,7 +2788,9 @@ class ModelTesterMixin:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
pt_model_loaded = model_class.from_pretrained(
|
||||
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
|
||||
)
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model_loaded.to(torch_device)
|
||||
@@ -3724,6 +3726,11 @@ class ModelTesterMixin:
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
# FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors.
|
||||
# These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask.
|
||||
# This means that the class needs to be instantiated much later, after `use_mask_token` is set, which means a significant refactor of the code.
|
||||
# However, masking there is not applied at any layer that matters (i.e. self-attention), so we can safely deactivate it.
|
||||
deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters
|
||||
|
||||
is_encoder_decoder = model.config.is_encoder_decoder
|
||||
|
||||
@@ -3861,6 +3868,27 @@ class ModelTesterMixin:
|
||||
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
|
||||
):
|
||||
processed_inputs["output_attentions"] = output_attentions
|
||||
if not deactivate_mask and (
|
||||
"bool_masked_pos" in inspect.signature(model_eager.forward).parameters
|
||||
):
|
||||
dummy_mask = torch.ones((self.model_tester.num_masks,))
|
||||
|
||||
# In case of additional token (like class) we define a custom `mask_length`
|
||||
if hasattr(self.model_tester, "mask_length"):
|
||||
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
|
||||
else:
|
||||
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
|
||||
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
|
||||
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
|
||||
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
|
||||
|
||||
if "noise" in inspect.signature(model_eager.forward).parameters:
|
||||
np.random.seed(2)
|
||||
num_patches = int(
|
||||
(self.model_tester.image_size // self.model_tester.patch_size) ** 2
|
||||
)
|
||||
noise = np.random.uniform(size=(batch_size, num_patches))
|
||||
processed_inputs["noise"] = torch.from_numpy(noise)
|
||||
|
||||
# TODO: test gradients as well (& for FA2 as well!)
|
||||
with torch.no_grad():
|
||||
|
||||
@@ -371,7 +371,9 @@ class FlaxModelTesterMixin:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
pt_model_loaded = pt_model_class.from_pretrained(
|
||||
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
|
||||
)
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model_loaded.to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user