[FlaxBert] Add ForCausalLM (#16995)
* [FlaxBert] Add ForCausalLM * make style * fix output attentions * Add RobertaForCausalLM * remove comment * fix fx-to-pt model loading * remove comment * add modeling tests * add enc-dec model tests * add big_bird * add electra * make style * make repo-consitency * add to docs * remove roberta test * quality * amend cookiecutter * fix attention_mask bug in flax bert model tester * tighten pt-fx thresholds to 1e-5 * add 'copied from' statements * amend 'copied from' statements * amend 'copied from' statements * quality
This commit is contained in:
@@ -19,7 +19,7 @@ import numpy as np
|
||||
from transformers import BertConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -114,6 +114,22 @@ class FlaxBertModelTester(unittest.TestCase):
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, token_type_ids, attention_mask = config_and_inputs
|
||||
|
||||
config.is_decoder = True
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@@ -25,6 +25,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random
|
||||
if is_flax_available():
|
||||
import jax
|
||||
from transformers.models.big_bird.modeling_flax_big_bird import (
|
||||
FlaxBigBirdForCausalLM,
|
||||
FlaxBigBirdForMaskedLM,
|
||||
FlaxBigBirdForMultipleChoice,
|
||||
FlaxBigBirdForPreTraining,
|
||||
@@ -136,6 +137,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
FlaxBigBirdForCausalLM,
|
||||
FlaxBigBirdModel,
|
||||
FlaxBigBirdForPreTraining,
|
||||
FlaxBigBirdForMaskedLM,
|
||||
|
||||
@@ -10,6 +10,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random
|
||||
|
||||
if is_flax_available():
|
||||
from transformers.models.electra.modeling_flax_electra import (
|
||||
FlaxElectraForCausalLM,
|
||||
FlaxElectraForMaskedLM,
|
||||
FlaxElectraForMultipleChoice,
|
||||
FlaxElectraForPreTraining,
|
||||
@@ -110,6 +111,7 @@ class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
FlaxElectraModel,
|
||||
FlaxElectraForCausalLM,
|
||||
FlaxElectraForMaskedLM,
|
||||
FlaxElectraForPreTraining,
|
||||
FlaxElectraForTokenClassification,
|
||||
|
||||
@@ -33,6 +33,7 @@ if is_flax_available():
|
||||
AutoTokenizer,
|
||||
EncoderDecoderConfig,
|
||||
FlaxBartForCausalLM,
|
||||
FlaxBertForCausalLM,
|
||||
FlaxBertModel,
|
||||
FlaxEncoderDecoderModel,
|
||||
FlaxGPT2LMHeadModel,
|
||||
@@ -545,6 +546,43 @@ class FlaxBartEncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase
|
||||
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "facebook/bart-base")
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxBertEncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = FlaxBertModel(config)
|
||||
decoder_model = FlaxBertForCausalLM(decoder_config)
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
model_tester_encoder = FlaxBertModelTester(self, batch_size=13)
|
||||
model_tester_decoder = FlaxBertModelTester(self, batch_size=13)
|
||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||
(config, input_ids, token_type_ids, attention_mask) = encoder_config_and_inputs
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
def get_pretrained_model(self):
|
||||
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased")
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxEncoderDecoderModelTest(unittest.TestCase):
|
||||
def get_from_encoderdecoder_pretrained_model(self):
|
||||
|
||||
@@ -19,11 +19,12 @@ import numpy as np
|
||||
from transformers import RobertaConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers.models.roberta.modeling_flax_roberta import (
|
||||
FlaxRobertaForCausalLM,
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
@@ -112,6 +113,22 @@ class FlaxRobertaModelTester(unittest.TestCase):
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, token_type_ids, attention_mask = config_and_inputs
|
||||
|
||||
config.is_decoder = True
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
@@ -121,6 +138,7 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
FlaxRobertaModel,
|
||||
FlaxRobertaForCausalLM,
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForSequenceClassification,
|
||||
FlaxRobertaForTokenClassification,
|
||||
|
||||
@@ -22,6 +22,7 @@ from transformers import is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
|
||||
|
||||
from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester
|
||||
from ..bert.test_modeling_flax_bert import FlaxBertModelTester
|
||||
from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester
|
||||
from ..test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from ..wav2vec2.test_modeling_flax_wav2vec2 import FlaxWav2Vec2ModelTester
|
||||
@@ -34,6 +35,7 @@ if is_flax_available():
|
||||
from flax.traverse_util import flatten_dict
|
||||
from transformers import (
|
||||
FlaxBartForCausalLM,
|
||||
FlaxBertForCausalLM,
|
||||
FlaxGPT2LMHeadModel,
|
||||
FlaxSpeechEncoderDecoderModel,
|
||||
FlaxWav2Vec2Model,
|
||||
@@ -807,3 +809,118 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"facebook/wav2vec2-large-lv60", "bert-large-uncased"
|
||||
)
|
||||
batch_size = 13
|
||||
input_values = floats_tensor([batch_size, 512], model.config.encoder.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 512])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], model.config.decoder.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {
|
||||
"inputs": input_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
return model, inputs
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = FlaxWav2Vec2Model(config)
|
||||
decoder_model = FlaxBertForCausalLM(decoder_config)
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
model_tester_encoder = FlaxWav2Vec2ModelTester(self, batch_size=13)
|
||||
model_tester_decoder = FlaxBertModelTester(self, batch_size=13)
|
||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||
(config, inputs, attention_mask) = encoder_config_and_inputs
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"inputs": inputs,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
@slow
|
||||
def test_flaxwav2vec2bert_pt_flax_equivalence(self):
|
||||
pt_model = SpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large")
|
||||
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large", from_pt=True)
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
batch_size = 13
|
||||
input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 512])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs_dict = {
|
||||
"inputs": input_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
pt_logits = pt_outputs.logits
|
||||
pt_outputs = pt_outputs.to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict)
|
||||
fx_logits = fx_outputs.logits
|
||||
fx_outputs = fx_outputs.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
||||
fx_logits_loaded = fx_outputs_loaded.logits
|
||||
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
pt_logits_loaded = pt_outputs_loaded.logits
|
||||
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||
|
||||
Reference in New Issue
Block a user