Update quality tooling for formatting (#21480)
* Result of black 23.1 * Update target to Python 3.7 * Switch flake8 to ruff * Configure isort * Configure isort * Apply isort with line limit * Put the right black version * adapt black in check copies * Fix copies
This commit is contained in:
@@ -33,6 +33,7 @@ if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
from flax.training.common_utils import onehot
|
||||
from flax.traverse_util import flatten_dict
|
||||
|
||||
from transformers import (
|
||||
FlaxBartForCausalLM,
|
||||
FlaxBertForCausalLM,
|
||||
@@ -73,7 +74,7 @@ class FlaxEncoderDecoderMixin:
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
|
||||
@@ -103,7 +104,7 @@ class FlaxEncoderDecoderMixin:
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
@@ -142,7 +143,7 @@ class FlaxEncoderDecoderMixin:
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
return_dict,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
|
||||
@@ -169,7 +170,7 @@ class FlaxEncoderDecoderMixin:
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
|
||||
@@ -208,7 +209,7 @@ class FlaxEncoderDecoderMixin:
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
# assert that loading encoder and decoder models from configs has been correctly executed
|
||||
@@ -253,7 +254,7 @@ class FlaxEncoderDecoderMixin:
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
@@ -336,7 +337,7 @@ class FlaxEncoderDecoderMixin:
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
enc_dec_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
@@ -406,7 +407,6 @@ class FlaxEncoderDecoderMixin:
|
||||
self.assertTrue((grad == grad_frozen).all())
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
@@ -448,7 +448,6 @@ class FlaxEncoderDecoderMixin:
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
|
||||
|
||||
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
@@ -460,7 +459,6 @@ class FlaxEncoderDecoderMixin:
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
|
||||
|
||||
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
pt_model = SpeechEncoderDecoderModel(encoder_decoder_config)
|
||||
@@ -508,7 +506,6 @@ class FlaxEncoderDecoderMixin:
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
config = config_inputs_dict.pop("config")
|
||||
decoder_config = config_inputs_dict.pop("decoder_config")
|
||||
|
||||
@@ -62,7 +62,7 @@ class EncoderDecoderMixin:
|
||||
decoder_attention_mask,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
|
||||
@@ -95,7 +95,7 @@ class EncoderDecoderMixin:
|
||||
decoder_attention_mask,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
@@ -135,7 +135,7 @@ class EncoderDecoderMixin:
|
||||
decoder_attention_mask,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_values if input_features is None else input_features
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
@@ -173,7 +173,7 @@ class EncoderDecoderMixin:
|
||||
return_dict,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
|
||||
@@ -202,7 +202,7 @@ class EncoderDecoderMixin:
|
||||
decoder_attention_mask,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
@@ -245,7 +245,7 @@ class EncoderDecoderMixin:
|
||||
decoder_attention_mask,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
@@ -292,7 +292,7 @@ class EncoderDecoderMixin:
|
||||
labels=None,
|
||||
input_values=None,
|
||||
input_features=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
|
||||
Reference in New Issue
Block a user